From b1f45223b270ac7ee1b8b92e4b3dc37a0186a265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sun, 15 Sep 2024 18:20:34 +0200 Subject: [PATCH] Low-level API ( `AsyncBackend` ): Added `create_fair_lock()` method (#348) --- docs/source/api/lowlevel/async/backend.rst | 2 + src/easynetwork/clients/async_tcp.py | 2 +- src/easynetwork/clients/async_udp.py | 2 +- .../api_async/backend/_asyncio/backend.py | 4 + .../api_async/backend/_common/fair_lock.py | 97 +++++++ .../api_async/backend/_trio/_trio_utils.py | 46 +++ .../api_async/backend/_trio/backend.py | 5 + .../backend/_trio/datagram/listener.py | 7 +- .../lowlevel/api_async/backend/abc.py | 16 +- .../lowlevel/api_async/servers/datagram.py | 26 +- src/easynetwork/servers/async_tcp.py | 2 +- .../test_backend/test_asyncio_backend.py | 1 - .../test_async/test_backend/test_fair_lock.py | 262 ++++++++++++++++++ .../test_communication/test_end2end.py | 32 ++- .../test_asyncio_backend/test_backend.py | 7 +- .../test_asyncio_backend/test_stream.py | 2 + .../test_backend/test_backend.py | 14 + .../test_servers/test_datagram.py | 19 +- .../test_trio_backend/test_backend.py | 13 + 19 files changed, 526 insertions(+), 33 deletions(-) create mode 100644 src/easynetwork/lowlevel/api_async/backend/_common/fair_lock.py create mode 100644 tests/functional_test/test_async/test_backend/test_fair_lock.py diff --git a/docs/source/api/lowlevel/async/backend.rst b/docs/source/api/lowlevel/async/backend.rst index f4b7ff29..80526b11 100644 --- a/docs/source/api/lowlevel/async/backend.rst +++ b/docs/source/api/lowlevel/async/backend.rst @@ -215,6 +215,8 @@ Locks .. automethod:: AsyncBackend.create_lock +.. automethod:: AsyncBackend.create_fair_lock + .. autoprotocol:: ILock Events diff --git a/src/easynetwork/clients/async_tcp.py b/src/easynetwork/clients/async_tcp.py index 370d0e9e..9cd2cf73 100644 --- a/src/easynetwork/clients/async_tcp.py +++ b/src/easynetwork/clients/async_tcp.py @@ -260,7 +260,7 @@ def __init__( self.__socket_connector_lock: ILock = backend.create_lock() self.__receive_lock: ILock = backend.create_lock() - self.__send_lock: ILock = backend.create_lock() + self.__send_lock: ILock = backend.create_fair_lock() self.__expected_recv_size: int = max_recv_size diff --git a/src/easynetwork/clients/async_udp.py b/src/easynetwork/clients/async_udp.py index c60421bd..a476fd74 100644 --- a/src/easynetwork/clients/async_udp.py +++ b/src/easynetwork/clients/async_udp.py @@ -147,7 +147,7 @@ def __init__( ) self.__socket_connector_lock: ILock = backend.create_lock() self.__receive_lock: ILock = backend.create_lock() - self.__send_lock: ILock = backend.create_lock() + self.__send_lock: ILock = backend.create_fair_lock() @staticmethod async def __create_socket( diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py index c6c6b2fa..a48b2356 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py @@ -273,6 +273,10 @@ async def create_udp_listeners( def create_lock(self) -> ILock: return self.__asyncio.Lock() + def create_fair_lock(self) -> ILock: + # For now, asyncio.Lock is already a fair (and fast) lock. + return self.__asyncio.Lock() + def create_event(self) -> IEvent: return self.__asyncio.Event() diff --git a/src/easynetwork/lowlevel/api_async/backend/_common/fair_lock.py b/src/easynetwork/lowlevel/api_async/backend/_common/fair_lock.py new file mode 100644 index 00000000..c2ff815d --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_common/fair_lock.py @@ -0,0 +1,97 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Fair lock module.""" + +from __future__ import annotations + +__all__ = ["FairLock"] + +from collections import deque +from types import TracebackType + +from .... import _utils +from ..abc import AsyncBackend, IEvent, ILock + + +class FairLock: + """ + A Lock object for inter-task synchronization where tasks are guaranteed to acquire the lock in strict + first-come-first-served order. This means that it always goes to the task which has been waiting longest. + """ + + def __init__(self, backend: AsyncBackend) -> None: + self._backend: AsyncBackend = backend + self._waiters: deque[IEvent] | None = None + self._locked: bool = False + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + @_utils.inherit_doc(ILock) + async def __aenter__(self) -> None: + await self.acquire() + + @_utils.inherit_doc(ILock) + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + /, + ) -> None: + self.release() + + @_utils.inherit_doc(ILock) + async def acquire(self) -> None: + if self._locked or self._waiters: + if self._waiters is None: + self._waiters = deque() + + waiter = self._backend.create_event() + self._waiters.append(waiter) + try: + try: + await waiter.wait() + finally: + self._waiters.remove(waiter) + except BaseException: + if not self._locked: + self._wake_up_first() + raise + + self._locked = True + + @_utils.inherit_doc(ILock) + def release(self) -> None: + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock not acquired") + + def _wake_up_first(self) -> None: + if not self._waiters: + return + + waiter = self._waiters[0] + waiter.set() + + @_utils.inherit_doc(ILock) + def locked(self) -> bool: + return self._locked diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py b/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py index bbf30d11..a29b3003 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py @@ -75,3 +75,49 @@ def __get_error_from_cause( error.__cause__ = exc_value error.__suppress_context__ = True return error.with_traceback(None) + + +class FastFIFOLock: + + def __init__(self) -> None: + self._locked: bool = False + self._lot: trio.lowlevel.ParkingLot = trio.lowlevel.ParkingLot() + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._lot: + extra = f"{extra}, waiters:{len(self._lot)}" + return f"<{res[1:-1]} [{extra}]>" + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + /, + ) -> None: + self.release() + + async def acquire(self) -> None: + if self._locked or self._lot: + await self._lot.park() + if not self._locked: + raise AssertionError("should be acquired") + else: + self._locked = True + + def release(self) -> None: + if self._locked: + if self._lot: + self._lot.unpark(count=1) + else: + self._locked = False + else: + raise RuntimeError("Lock not acquired") + + def locked(self) -> bool: + return self._locked diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py b/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py index 429654f3..ae9b3ca8 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py @@ -254,6 +254,11 @@ async def create_udp_listeners( def create_lock(self) -> ILock: return self.__trio.Lock() + def create_fair_lock(self) -> ILock: + from ._trio_utils import FastFIFOLock + + return FastFIFOLock() + def create_event(self) -> IEvent: return self.__trio.Event() diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py index 7404c432..ba51c5a7 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py @@ -29,7 +29,7 @@ from ..... import _utils, socket as socket_tools from ....transports.abc import AsyncDatagramListener -from ...abc import AsyncBackend, TaskGroup +from ...abc import AsyncBackend, ILock, TaskGroup @final @@ -39,6 +39,7 @@ class TrioDatagramListenerSocketAdapter(AsyncDatagramListener[tuple[Any, ...]]): "__listener", "__trsock", "__serve_guard", + "__send_lock", ) from .....constants import MAX_DATAGRAM_BUFSIZE @@ -53,6 +54,7 @@ def __init__(self, backend: AsyncBackend, sock: trio.socket.SocketType) -> None: self.__listener: trio.socket.SocketType = sock self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock) self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.") + self.__send_lock: ILock = backend.create_fair_lock() def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -93,7 +95,8 @@ async def serve( raise AssertionError("Expected code to be unreachable.") async def send_to(self, data: bytes | bytearray | memoryview, address: tuple[Any, ...]) -> None: - await self.__listener.sendto(data, address) + async with self.__send_lock: + await self.__listener.sendto(data, address) def backend(self) -> AsyncBackend: return self.__backend diff --git a/src/easynetwork/lowlevel/api_async/backend/abc.py b/src/easynetwork/lowlevel/api_async/backend/abc.py index a889b958..b25cb471 100644 --- a/src/easynetwork/lowlevel/api_async/backend/abc.py +++ b/src/easynetwork/lowlevel/api_async/backend/abc.py @@ -1144,6 +1144,20 @@ def create_lock(self) -> ILock: """ raise NotImplementedError + def create_fair_lock(self) -> ILock: + """ + Creates a Lock object for inter-task synchronization where tasks are guaranteed to acquire the lock in strict + first-come-first-served order. + + This means that it always goes to the task which has been waiting longest. + + Returns: + A new fair Lock. + """ + from ._common.fair_lock import FairLock + + return FairLock(self) + @abstractmethod def create_event(self) -> IEvent: """ @@ -1236,4 +1250,4 @@ def __enter__(self) -> CancelScope: def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: self.scope.__exit__(exc_type, exc_val, exc_tb) if self.scope.cancelled_caught(): - raise TimeoutError("timed out") + raise TimeoutError("timed out") from exc_val diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index d05704fc..add00180 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -186,14 +186,11 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None: client = client_cache[address] except KeyError: client_cache[address] = client = _ClientToken(DatagramClientContext(address, self), _ClientData(backend)) - notify = False - else: - notify = True - await client.data.push_datagram(datagram, notify=notify) + nb_datagrams_in_queue = await client.data.push_datagram(datagram) + del datagram - if client.data.state is None: - del datagram + if client.data.state is None and nb_datagrams_in_queue > 0: client.data.mark_pending() await self.__client_coroutine(datagram_received_cb, client, task_group, default_context) @@ -217,7 +214,7 @@ async def __client_coroutine( client_data=client.data, ) finally: - self.__on_task_done( + self.__on_client_coroutine_task_done( datagram_received_cb=datagram_received_cb, client=client, task_group=task_group, @@ -231,8 +228,8 @@ async def __client_coroutine_inner_loop( client_data: _ClientData, ) -> None: timeout: float | None - datagram: bytes = client_data.pop_datagram_no_wait() try: + datagram: bytes = client_data.pop_datagram_no_wait() # Ignore sent timeout here, we already have the datagram. await anext_without_asyncgen_hook(request_handler_generator) except StopAsyncIteration: @@ -249,9 +246,10 @@ async def __client_coroutine_inner_loop( del datagram null_timeout_ctx = contextlib.nullcontext() + backend = client_data.backend while True: try: - with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout): + with null_timeout_ctx if timeout is None else backend.timeout(timeout): datagram = await client_data.pop_datagram() action = self.__parse_datagram(datagram, self.__protocol) except BaseException as exc: @@ -267,7 +265,7 @@ async def __client_coroutine_inner_loop( finally: await request_handler_generator.aclose() - def __on_task_done( + def __on_client_coroutine_task_done( self, datagram_received_cb: Callable[ [DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request] @@ -372,12 +370,16 @@ def state(self) -> _ClientState | None: def queue_is_empty(self) -> bool: return not self._datagram_queue - async def push_datagram(self, datagram: bytes, *, notify: bool) -> None: + async def push_datagram(self, datagram: bytes) -> int: self._datagram_queue.append(datagram) - if notify: + + # Do not need to notify anyone if state is None. + if self.__state is not None: async with (queue_condition := self._queue_condition): queue_condition.notify() + return len(self._datagram_queue) + def pop_datagram_no_wait(self) -> bytes: return self._datagram_queue.popleft() diff --git a/src/easynetwork/servers/async_tcp.py b/src/easynetwork/servers/async_tcp.py index e910b0e6..4a3652b0 100644 --- a/src/easynetwork/servers/async_tcp.py +++ b/src/easynetwork/servers/async_tcp.py @@ -384,7 +384,7 @@ def __init__( ) -> None: self.__client: _stream_server.ConnectedStreamClient[_T_Response] = client self.__closing: bool = False - self.__send_lock = client.backend().create_lock() + self.__send_lock = client.backend().create_fair_lock() self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket)) self.__address: SocketAddress = address self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index 648148ee..d6128b70 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -32,7 +32,6 @@ class ExceptionCaughtDict(TypedDict, total=False): transport: asyncio.BaseTransport -@pytest.mark.flaky(retries=3, delay=0) class TestAsyncioBackendBootstrap: @pytest.fixture(scope="class") @staticmethod diff --git a/tests/functional_test/test_async/test_backend/test_fair_lock.py b/tests/functional_test/test_async/test_backend/test_fair_lock.py new file mode 100644 index 00000000..a05fe1d9 --- /dev/null +++ b/tests/functional_test/test_async/test_backend/test_fair_lock.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeVarTuple + +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend, IEvent, ILock +from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend + +import pytest +import sniffio + +if TYPE_CHECKING: + from pytest_mock import MockerFixture + + +_T_Args = TypeVarTuple("_T_Args") + + +class TestFairLock: + @pytest.fixture( + scope="class", + autouse=True, + params=[ + pytest.param("asyncio", marks=pytest.mark.asyncio), + pytest.param("trio", marks=pytest.mark.feature_trio(async_test_auto_mark=True)), + ], + ) + @staticmethod + def backend(request: pytest.FixtureRequest) -> AsyncBackend: + return new_builtin_backend(request.param) + + @pytest.fixture(params=["default", "custom"]) + @staticmethod + def fair_lock(request: pytest.FixtureRequest, backend: AsyncBackend) -> ILock: + from easynetwork.lowlevel.api_async.backend._common.fair_lock import FairLock + + match request.param: + case "default": + return FairLock(backend) + case "custom": + lock = backend.create_fair_lock() + if isinstance(lock, FairLock): + pytest.skip("uses default implementation with 'custom' parameter") + return lock + case _: + pytest.fail(f"Invalid param: {request.param!r}") + + @staticmethod + def call_soon(f: Callable[[*_T_Args], Any], *args: *_T_Args) -> None: + match sniffio.current_async_library(): + case "asyncio": + import asyncio + + asyncio.get_running_loop().call_soon(f, *args) + case "trio": + import trio + + trio.lowlevel.current_trio_token().run_sync_soon(f, *args) + case _: + pytest.fail("unknown async library") + + async def test____acquire____fifo_acquire____manual( + self, + fair_lock: ILock, + backend: AsyncBackend, + ) -> None: + results: list[int] = [] + + async def coroutine(index: int) -> None: + await fair_lock.acquire() + try: + results.append(index) + finally: + fair_lock.release() + + assert not fair_lock.locked() + async with backend.create_task_group() as task_group: + + await fair_lock.acquire() + try: + assert fair_lock.locked() + results.append(0) + + for index in range(1, 4): + await task_group.start(coroutine, index) + finally: + fair_lock.release() + + assert not fair_lock.locked() + assert results == [0, 1, 2, 3] + + async def test____acquire____fifo_acquire____context_manager( + self, + fair_lock: ILock, + backend: AsyncBackend, + ) -> None: + results: list[int] = [] + + async def coroutine(index: int) -> None: + async with fair_lock: + results.append(index) + + assert not fair_lock.locked() + async with backend.create_task_group() as task_group: + async with fair_lock: + assert fair_lock.locked() + results.append(0) + + for index in range(1, 4): + await task_group.start(coroutine, index) + + assert not fair_lock.locked() + assert results == [0, 1, 2, 3] + + async def test____acquire____fast_acquire( + self, + fair_lock: ILock, + backend: AsyncBackend, + mocker: MockerFixture, + ) -> None: + other_task = mocker.async_stub() + + assert not fair_lock.locked() + async with backend.create_task_group() as task_group: + task_group.start_soon(other_task) + async with fair_lock: + assert fair_lock.locked() + + other_task.assert_called_once() + other_task.assert_not_awaited() + + async def test____acquire____cancelled( + self, + fair_lock: ILock, + backend: AsyncBackend, + ) -> None: + async with backend.create_task_group() as task_group: + + await fair_lock.acquire() + + acquire_task = await task_group.start(fair_lock.acquire) + + assert acquire_task.cancel() + + with pytest.raises(backend.get_cancelled_exc_class()): + await acquire_task.join() + + assert fair_lock.locked() + fair_lock.release() + assert not fair_lock.locked() + + # Taken from asyncio.Lock unit tests + async def test____acquire____cancel_race( + self, + fair_lock: ILock, + backend: AsyncBackend, + ) -> None: + # Several tasks: + # - A acquires the lock + # - B is blocked in acquire() + # - C is blocked in acquire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + async def lockit(name: str, blocker: IEvent | None) -> None: + await fair_lock.acquire() + try: + if blocker is not None: + await blocker.wait() + finally: + fair_lock.release() + + async with backend.create_task_group() as task_group: + + blocker_a = backend.create_event() + await task_group.start(lockit, "A", blocker_a) + assert fair_lock.locked() + + task_b = await task_group.start(lockit, "B", None) + task_c = await task_group.start(lockit, "C", None) + + # Create the race and check. + # Without the fix this failed at the last assert. + blocker_a.set() + task_b.cancel() + with backend.timeout(0.200): + await task_c.join() + + # Taken from asyncio.Lock unit tests + async def test____acquire____cancel_release_race( + self, + fair_lock: ILock, + backend: AsyncBackend, + ) -> None: + # Acquire 4 locks, cancel second, release first + # and 2 locks are taken at once. + + lock_count: int = 0 + call_count: int = 0 + + async def lockit() -> None: + nonlocal lock_count + nonlocal call_count + call_count += 1 + await fair_lock.acquire() + lock_count += 1 + + async with backend.create_task_group() as task_group: + + await fair_lock.acquire() + + # Start scheduled tasks + t1 = await task_group.start(lockit) + t2 = await task_group.start(lockit) + t3 = await task_group.start(lockit) + + def trigger() -> None: + t1.cancel() + fair_lock.release() + + self.call_soon(trigger) + with pytest.raises(backend.get_cancelled_exc_class()): + # Wait for cancellation + await t1.join() + + # Make sure only one lock was taken + for _ in range(3): + if not lock_count: + await backend.coro_yield() + assert lock_count == 1 + # While 3 calls were made to lockit() + assert call_count == 3 + assert t1.cancelled() and t2.done() + + # Cleanup the task that is stuck on acquire. + t3.cancel() + with backend.move_on_after(0.200): + await t3.wait() + assert t3.cancelled() + + async def test____release____not_acquired( + self, + fair_lock: ILock, + ) -> None: + with pytest.raises(RuntimeError): + fair_lock.release() + + async def test____release____no_waiters( + self, + fair_lock: ILock, + ) -> None: + await fair_lock.acquire() + assert fair_lock.locked() + + fair_lock.release() + assert not fair_lock.locked() diff --git a/tests/functional_test/test_communication/test_end2end.py b/tests/functional_test/test_communication/test_end2end.py index 9947886e..df42bf22 100644 --- a/tests/functional_test/test_communication/test_end2end.py +++ b/tests/functional_test/test_communication/test_end2end.py @@ -76,7 +76,7 @@ def server_address(server: StandaloneTCPNetworkServer[str, str]) -> tuple[str, i port = server.get_addresses()[0].port return ("localhost", port) - def test____blocking_client____echo( + def test____tcp_blocking_client____echo( self, server_address: tuple[str, int], stream_protocol: AnyStreamProtocolType[str, str], @@ -92,10 +92,14 @@ def test____blocking_client____echo( # Several write for i in range(3): client.send_packet(f"Hello world {i}") + responses: list[str] = [] + expected: list[str] = [] for i in range(3): - assert client.recv_packet(timeout=1) == f"Hello world {i}" + responses.append(client.recv_packet(timeout=1)) + expected.append(f"Hello world {i}") + assert responses == expected - async def test____asynchronous_client____echo( + async def test____tcp_asynchronous_client____echo( self, async_client_backend: BuiltinAsyncBackendLiteral, server_address: tuple[str, int], @@ -113,9 +117,13 @@ async def test____asynchronous_client____echo( # Several write for i in range(3): await client.send_packet(f"Hello world {i}") + responses: list[str] = [] + expected: list[str] = [] for i in range(3): with client.backend().timeout(1): - assert (await client.recv_packet()) == f"Hello world {i}" + responses.append(await client.recv_packet()) + expected.append(f"Hello world {i}") + assert responses == expected class TestNetworkUDP(BaseTestNetworkServer): @@ -133,7 +141,7 @@ def server_address(server: StandaloneUDPNetworkServer[str, str]) -> tuple[str, i port = server.get_addresses()[0].port return ("127.0.0.1", port) - def test____blocking_client____echo( + def test____udp_blocking_client____echo( self, server_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], @@ -149,10 +157,14 @@ def test____blocking_client____echo( # Several write for i in range(3): client.send_packet(f"Hello world {i}") + responses: list[str] = [] + expected: list[str] = [] for i in range(3): - assert client.recv_packet(timeout=1) == f"Hello world {i}" + responses.append(client.recv_packet(timeout=1)) + expected.append(f"Hello world {i}") + assert responses == expected - async def test____asynchronous_client____echo( + async def test____udp_asynchronous_client____echo( self, async_client_backend: BuiltinAsyncBackendLiteral, server_address: tuple[str, int], @@ -170,6 +182,10 @@ async def test____asynchronous_client____echo( # Several write for i in range(3): await client.send_packet(f"Hello world {i}") + responses: list[str] = [] + expected: list[str] = [] for i in range(3): with client.backend().timeout(1): - assert (await client.recv_packet()) == f"Hello world {i}" + responses.append(await client.recv_packet()) + expected.append(f"Hello world {i}") + assert responses == expected diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py index af121cc2..1ea45716 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py @@ -755,8 +755,10 @@ async def test____create_udp_listeners____bind_to_all_interfaces( ] assert listener_sockets == [mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4] + @pytest.mark.parametrize("fair_lock", [False, True], ids=lambda p: f"fair_lock=={p}") async def test____create_lock____use_asyncio_Lock_class( self, + fair_lock: bool, backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: @@ -764,7 +766,10 @@ async def test____create_lock____use_asyncio_Lock_class( mock_Lock = mocker.patch("asyncio.Lock", return_value=mocker.sentinel.lock) # Act - lock = backend.create_lock() + if fair_lock: + lock = backend.create_fair_lock() + else: + lock = backend.create_lock() # Assert mock_Lock.assert_called_once_with() diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index 78c05230..7e788d01 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -425,8 +425,10 @@ async def test____accept____busy( self, listener: ListenerSocketAdapter[Any], mock_tcp_listener_socket: MagicMock, + mock_tcp_socket_factory: Callable[[], MagicMock], ) -> None: # Arrange + mock_tcp_listener_socket.accept.return_value = (mock_tcp_socket_factory(), ("127.0.0.1", 12345)) with self._set_sock_method_in_blocking_state(mock_tcp_listener_socket.accept): _ = await self._busy_socket_task(listener.raw_accept(), mock_tcp_listener_socket.accept) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py index 3207b29b..ddc259a5 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py @@ -4,6 +4,7 @@ from collections.abc import Awaitable from typing import TYPE_CHECKING, Any, final, no_type_check +from easynetwork.lowlevel.api_async.backend._common.fair_lock import FairLock from easynetwork.lowlevel.api_async.backend.abc import TaskInfo import pytest @@ -175,3 +176,16 @@ async def test____getnameinfo____run_stdlib_function_in_thread( ), abandon_on_cancel=True, ) + + async def test____create_fair_lock____returns_default_impl( + self, + backend: MockBackend, + ) -> None: + # Arrange + + # Act + lock = backend.create_fair_lock() + + # Assert + assert isinstance(lock, FairLock) + assert lock._backend is backend diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py index 225ce4fc..14602239 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py @@ -286,14 +286,21 @@ async def test____datagram_queue____push_datagram( client_data._queue_condition = queue_condition # Act - await client_data.push_datagram(b"datagram_1", notify=notify) - await client_data.push_datagram(b"datagram_2", notify=notify) - await client_data.push_datagram(b"datagram_3", notify=notify) + n = await client_data.push_datagram(b"datagram_1") + assert n == 1 + if notify: + client_data.mark_pending() + n = await client_data.push_datagram(b"datagram_2") + assert n == 2 + if notify: + client_data.mark_running() + n = await client_data.push_datagram(b"datagram_3") + assert n == 3 # Assert assert list(client_data._datagram_queue) == [b"datagram_1", b"datagram_2", b"datagram_3"] if notify: - assert queue_condition.notify.call_count == 3 + assert queue_condition.notify.call_count == 2 else: queue_condition.notify.assert_not_called() @@ -336,12 +343,14 @@ async def test____datagram_queue____pop_datagram____wait_until_notification( client_data: _ClientData, ) -> None: # Arrange + client_data.mark_pending() + client_data.mark_running() pop_datagram_task = asyncio.create_task(client_data.pop_datagram()) await asyncio.sleep(0.01) assert not pop_datagram_task.done() # Act - await client_data.push_datagram(b"datagram_1", notify=True) + await client_data.push_datagram(b"datagram_1") # Assert assert (await pop_datagram_task) == b"datagram_1" diff --git a/tests/unit_test/test_async/test_trio_backend/test_backend.py b/tests/unit_test/test_async/test_trio_backend/test_backend.py index 0c5c6af6..c76dc13d 100644 --- a/tests/unit_test/test_async/test_trio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_trio_backend/test_backend.py @@ -626,6 +626,19 @@ async def test____create_lock____use_trio_Lock_class( mock_Lock.assert_called_once_with() assert lock is mocker.sentinel.lock + async def test____create_fair_lock____returns_custom_lock( + self, + backend: TrioBackend, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import FastFIFOLock + + # Act + lock = backend.create_fair_lock() + + # Assert + assert isinstance(lock, FastFIFOLock) + async def test____create_event____use_trio_Event_class( self, backend: TrioBackend,