Skip to content

Commit

Permalink
Low-level API ( AsyncBackend ): Added create_fair_lock() method (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Sep 15, 2024
1 parent ff8b8a7 commit b1f4522
Show file tree
Hide file tree
Showing 19 changed files with 526 additions and 33 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/lowlevel/async/backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ Locks

.. automethod:: AsyncBackend.create_lock

.. automethod:: AsyncBackend.create_fair_lock

.. autoprotocol:: ILock

Events
Expand Down
2 changes: 1 addition & 1 deletion src/easynetwork/clients/async_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def __init__(
self.__socket_connector_lock: ILock = backend.create_lock()

self.__receive_lock: ILock = backend.create_lock()
self.__send_lock: ILock = backend.create_lock()
self.__send_lock: ILock = backend.create_fair_lock()

self.__expected_recv_size: int = max_recv_size

Expand Down
2 changes: 1 addition & 1 deletion src/easynetwork/clients/async_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(
)
self.__socket_connector_lock: ILock = backend.create_lock()
self.__receive_lock: ILock = backend.create_lock()
self.__send_lock: ILock = backend.create_lock()
self.__send_lock: ILock = backend.create_fair_lock()

@staticmethod
async def __create_socket(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ async def create_udp_listeners(
def create_lock(self) -> ILock:
return self.__asyncio.Lock()

def create_fair_lock(self) -> ILock:
# For now, asyncio.Lock is already a fair (and fast) lock.
return self.__asyncio.Lock()

def create_event(self) -> IEvent:
return self.__asyncio.Event()

Expand Down
97 changes: 97 additions & 0 deletions src/easynetwork/lowlevel/api_async/backend/_common/fair_lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
"""Fair lock module."""

from __future__ import annotations

__all__ = ["FairLock"]

from collections import deque
from types import TracebackType

from .... import _utils
from ..abc import AsyncBackend, IEvent, ILock


class FairLock:
"""
A Lock object for inter-task synchronization where tasks are guaranteed to acquire the lock in strict
first-come-first-served order. This means that it always goes to the task which has been waiting longest.
"""

def __init__(self, backend: AsyncBackend) -> None:
self._backend: AsyncBackend = backend
self._waiters: deque[IEvent] | None = None
self._locked: bool = False

def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"

@_utils.inherit_doc(ILock)
async def __aenter__(self) -> None:
await self.acquire()

@_utils.inherit_doc(ILock)
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
/,
) -> None:
self.release()

@_utils.inherit_doc(ILock)
async def acquire(self) -> None:
if self._locked or self._waiters:
if self._waiters is None:
self._waiters = deque()

waiter = self._backend.create_event()
self._waiters.append(waiter)
try:
try:
await waiter.wait()
finally:
self._waiters.remove(waiter)
except BaseException:
if not self._locked:
self._wake_up_first()
raise

self._locked = True

@_utils.inherit_doc(ILock)
def release(self) -> None:
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock not acquired")

def _wake_up_first(self) -> None:
if not self._waiters:
return

waiter = self._waiters[0]
waiter.set()

@_utils.inherit_doc(ILock)
def locked(self) -> bool:
return self._locked
46 changes: 46 additions & 0 deletions src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,49 @@ def __get_error_from_cause(
error.__cause__ = exc_value
error.__suppress_context__ = True
return error.with_traceback(None)


class FastFIFOLock:

def __init__(self) -> None:
self._locked: bool = False
self._lot: trio.lowlevel.ParkingLot = trio.lowlevel.ParkingLot()

def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self._locked else "unlocked"
if self._lot:
extra = f"{extra}, waiters:{len(self._lot)}"
return f"<{res[1:-1]} [{extra}]>"

async def __aenter__(self) -> None:
await self.acquire()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
/,
) -> None:
self.release()

async def acquire(self) -> None:
if self._locked or self._lot:
await self._lot.park()
if not self._locked:
raise AssertionError("should be acquired")
else:
self._locked = True

def release(self) -> None:
if self._locked:
if self._lot:
self._lot.unpark(count=1)
else:
self._locked = False
else:
raise RuntimeError("Lock not acquired")

def locked(self) -> bool:
return self._locked
5 changes: 5 additions & 0 deletions src/easynetwork/lowlevel/api_async/backend/_trio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ async def create_udp_listeners(
def create_lock(self) -> ILock:
return self.__trio.Lock()

def create_fair_lock(self) -> ILock:
from ._trio_utils import FastFIFOLock

return FastFIFOLock()

def create_event(self) -> IEvent:
return self.__trio.Event()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from ..... import _utils, socket as socket_tools
from ....transports.abc import AsyncDatagramListener
from ...abc import AsyncBackend, TaskGroup
from ...abc import AsyncBackend, ILock, TaskGroup


@final
Expand All @@ -39,6 +39,7 @@ class TrioDatagramListenerSocketAdapter(AsyncDatagramListener[tuple[Any, ...]]):
"__listener",
"__trsock",
"__serve_guard",
"__send_lock",
)

from .....constants import MAX_DATAGRAM_BUFSIZE
Expand All @@ -53,6 +54,7 @@ def __init__(self, backend: AsyncBackend, sock: trio.socket.SocketType) -> None:
self.__listener: trio.socket.SocketType = sock
self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock)
self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.")
self.__send_lock: ILock = backend.create_fair_lock()

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
Expand Down Expand Up @@ -93,7 +95,8 @@ async def serve(
raise AssertionError("Expected code to be unreachable.")

async def send_to(self, data: bytes | bytearray | memoryview, address: tuple[Any, ...]) -> None:
await self.__listener.sendto(data, address)
async with self.__send_lock:
await self.__listener.sendto(data, address)

def backend(self) -> AsyncBackend:
return self.__backend
Expand Down
16 changes: 15 additions & 1 deletion src/easynetwork/lowlevel/api_async/backend/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,20 @@ def create_lock(self) -> ILock:
"""
raise NotImplementedError

def create_fair_lock(self) -> ILock:
"""
Creates a Lock object for inter-task synchronization where tasks are guaranteed to acquire the lock in strict
first-come-first-served order.
This means that it always goes to the task which has been waiting longest.
Returns:
A new fair Lock.
"""
from ._common.fair_lock import FairLock

return FairLock(self)

@abstractmethod
def create_event(self) -> IEvent:
"""
Expand Down Expand Up @@ -1236,4 +1250,4 @@ def __enter__(self) -> CancelScope:
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None:
self.scope.__exit__(exc_type, exc_val, exc_tb)
if self.scope.cancelled_caught():
raise TimeoutError("timed out")
raise TimeoutError("timed out") from exc_val
26 changes: 14 additions & 12 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,11 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None:
client = client_cache[address]
except KeyError:
client_cache[address] = client = _ClientToken(DatagramClientContext(address, self), _ClientData(backend))
notify = False
else:
notify = True

await client.data.push_datagram(datagram, notify=notify)
nb_datagrams_in_queue = await client.data.push_datagram(datagram)
del datagram

if client.data.state is None:
del datagram
if client.data.state is None and nb_datagrams_in_queue > 0:
client.data.mark_pending()
await self.__client_coroutine(datagram_received_cb, client, task_group, default_context)

Expand All @@ -217,7 +214,7 @@ async def __client_coroutine(
client_data=client.data,
)
finally:
self.__on_task_done(
self.__on_client_coroutine_task_done(
datagram_received_cb=datagram_received_cb,
client=client,
task_group=task_group,
Expand All @@ -231,8 +228,8 @@ async def __client_coroutine_inner_loop(
client_data: _ClientData,
) -> None:
timeout: float | None
datagram: bytes = client_data.pop_datagram_no_wait()
try:
datagram: bytes = client_data.pop_datagram_no_wait()
# Ignore sent timeout here, we already have the datagram.
await anext_without_asyncgen_hook(request_handler_generator)
except StopAsyncIteration:
Expand All @@ -249,9 +246,10 @@ async def __client_coroutine_inner_loop(

del datagram
null_timeout_ctx = contextlib.nullcontext()
backend = client_data.backend
while True:
try:
with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout):
with null_timeout_ctx if timeout is None else backend.timeout(timeout):
datagram = await client_data.pop_datagram()
action = self.__parse_datagram(datagram, self.__protocol)
except BaseException as exc:
Expand All @@ -267,7 +265,7 @@ async def __client_coroutine_inner_loop(
finally:
await request_handler_generator.aclose()

def __on_task_done(
def __on_client_coroutine_task_done(
self,
datagram_received_cb: Callable[
[DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request]
Expand Down Expand Up @@ -372,12 +370,16 @@ def state(self) -> _ClientState | None:
def queue_is_empty(self) -> bool:
return not self._datagram_queue

async def push_datagram(self, datagram: bytes, *, notify: bool) -> None:
async def push_datagram(self, datagram: bytes) -> int:
self._datagram_queue.append(datagram)
if notify:

# Do not need to notify anyone if state is None.
if self.__state is not None:
async with (queue_condition := self._queue_condition):
queue_condition.notify()

return len(self._datagram_queue)

def pop_datagram_no_wait(self) -> bytes:
return self._datagram_queue.popleft()

Expand Down
2 changes: 1 addition & 1 deletion src/easynetwork/servers/async_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def __init__(
) -> None:
self.__client: _stream_server.ConnectedStreamClient[_T_Response] = client
self.__closing: bool = False
self.__send_lock = client.backend().create_lock()
self.__send_lock = client.backend().create_fair_lock()
self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket))
self.__address: SocketAddress = address
self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class ExceptionCaughtDict(TypedDict, total=False):
transport: asyncio.BaseTransport


@pytest.mark.flaky(retries=3, delay=0)
class TestAsyncioBackendBootstrap:
@pytest.fixture(scope="class")
@staticmethod
Expand Down
Loading

0 comments on commit b1f4522

Please sign in to comment.