From 0cb93e77da570b877792c6a9cefb8b873f9c0dab Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 31 Dec 2023 17:24:09 +0100 Subject: [PATCH] [FIX] Synchronous clients' threading.Lock were replaced by threading.RLock in #135 --- src/easynetwork/api_sync/client/tcp.py | 7 ++++--- src/easynetwork/api_sync/client/udp.py | 5 +++-- tests/unit_test/_utils.py | 27 ++++++++++++++++++++++++-- tests/unit_test/test_sync/conftest.py | 4 ++-- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/easynetwork/api_sync/client/tcp.py b/src/easynetwork/api_sync/client/tcp.py index 0acdd7b8..d9715fde 100644 --- a/src/easynetwork/api_sync/client/tcp.py +++ b/src/easynetwork/api_sync/client/tcp.py @@ -21,6 +21,7 @@ import contextlib import errno as _errno import socket as _socket +import threading from collections.abc import Iterator from typing import TYPE_CHECKING, Any, final, overload @@ -240,10 +241,10 @@ def __init__( assert ssl_shared_lock is not None # nosec assert_used if ssl and ssl_shared_lock: - self.__send_lock = self.__receive_lock = _lock.ForkSafeLock() + self.__send_lock = self.__receive_lock = _lock.ForkSafeLock(threading.Lock) else: - self.__send_lock = _lock.ForkSafeLock() - self.__receive_lock = _lock.ForkSafeLock() + self.__send_lock = _lock.ForkSafeLock(threading.Lock) + self.__receive_lock = _lock.ForkSafeLock(threading.Lock) try: self.__endpoint = StreamEndpoint(transport, protocol, max_recv_size=max_recv_size) diff --git a/src/easynetwork/api_sync/client/udp.py b/src/easynetwork/api_sync/client/udp.py index ccf32aca..7558dc45 100644 --- a/src/easynetwork/api_sync/client/udp.py +++ b/src/easynetwork/api_sync/client/udp.py @@ -20,6 +20,7 @@ import contextlib import socket as _socket +import threading from collections.abc import Iterator from typing import Any, final, overload @@ -125,8 +126,8 @@ def __init__( socket.close() raise - self.__send_lock = _lock.ForkSafeLock() - self.__receive_lock = _lock.ForkSafeLock() + self.__send_lock = _lock.ForkSafeLock(threading.Lock) + self.__receive_lock = _lock.ForkSafeLock(threading.Lock) try: self.__endpoint: DatagramEndpoint[_T_SentPacket, _T_ReceivedPacket] = DatagramEndpoint(transport, protocol) self.__socket_proxy = SocketProxy(transport.extra(INETSocketAttribute.socket), lock=self.__send_lock.get) diff --git a/tests/unit_test/_utils.py b/tests/unit_test/_utils.py index 7ce0b236..333ec0f8 100644 --- a/tests/unit_test/_utils.py +++ b/tests/unit_test/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import threading from collections.abc import Sequence from socket import AF_INET, AF_INET6, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM from types import TracebackType @@ -10,17 +11,24 @@ class _LockMixin: + _owner: object = None _locked_count: int = 0 + _reentrant: bool = False class WouldBlock(Exception): pass + def _get_requester_id(self) -> object: + raise NotImplementedError + def _acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if self._locked_count > 0: + if self._locked_count > 0 and (not self._reentrant or self._owner != self._get_requester_id()): if not blocking or timeout >= 0: return False raise _LockMixin.WouldBlock(f"{self.__class__.__name__}.acquire() would block") + if self._reentrant and self._owner is None: + self._owner = self._get_requester_id() self._locked_count += 1 return True @@ -28,13 +36,17 @@ def locked(self) -> bool: return self._locked_count > 0 def _release(self) -> None: + if self._reentrant and self._owner != self._get_requester_id(): + raise RuntimeError("release() called on an unacquired lock") assert self._locked_count > 0 self._locked_count -= 1 + if self._reentrant and self._locked_count == 0: + self._owner = None class DummyLock(_LockMixin): """ - Helper class used to mock threading.Lock and threading.RLock classes + Helper class used to mock threading.Lock class. """ def __enter__(self) -> bool: @@ -50,6 +62,17 @@ def release(self) -> None: return self._release() +class DummyRLock(DummyLock): + """ + Helper class used to mock threading.RLock class. + """ + + _reentrant = True + + def _get_requester_id(self) -> object: + return threading.get_ident() + + class AsyncDummyLock(_LockMixin): """ Helper class used to mock asyncio.Lock classes diff --git a/tests/unit_test/test_sync/conftest.py b/tests/unit_test/test_sync/conftest.py index 0daad5f0..97bcdaf6 100644 --- a/tests/unit_test/test_sync/conftest.py +++ b/tests/unit_test/test_sync/conftest.py @@ -10,8 +10,8 @@ @pytest.fixture(autouse=True) def dummy_lock_cls(mocker: MockerFixture) -> Any: - from .._utils import DummyLock + from .._utils import DummyLock, DummyRLock mocker.patch("threading.Lock", new=DummyLock) - mocker.patch("threading.RLock", new=DummyLock) + mocker.patch("threading.RLock", new=DummyRLock) return DummyLock