Skip to content

Commit

Permalink
[FIX] Synchronous clients' threading.Lock were replaced by threading.…
Browse files Browse the repository at this point in the history
…RLock in #135
  • Loading branch information
francis-clairicia committed Dec 31, 2023
1 parent b6e5f23 commit 0cb93e7
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/easynetwork/api_sync/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
import errno as _errno
import socket as _socket
import threading
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, final, overload

Expand Down Expand Up @@ -240,10 +241,10 @@ def __init__(
assert ssl_shared_lock is not None # nosec assert_used

if ssl and ssl_shared_lock:
self.__send_lock = self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = self.__receive_lock = _lock.ForkSafeLock(threading.Lock)
else:
self.__send_lock = _lock.ForkSafeLock()
self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = _lock.ForkSafeLock(threading.Lock)
self.__receive_lock = _lock.ForkSafeLock(threading.Lock)

try:
self.__endpoint = StreamEndpoint(transport, protocol, max_recv_size=max_recv_size)
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/api_sync/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import contextlib
import socket as _socket
import threading
from collections.abc import Iterator
from typing import Any, final, overload

Expand Down Expand Up @@ -125,8 +126,8 @@ def __init__(
socket.close()
raise

self.__send_lock = _lock.ForkSafeLock()
self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = _lock.ForkSafeLock(threading.Lock)
self.__receive_lock = _lock.ForkSafeLock(threading.Lock)
try:
self.__endpoint: DatagramEndpoint[_T_SentPacket, _T_ReceivedPacket] = DatagramEndpoint(transport, protocol)
self.__socket_proxy = SocketProxy(transport.extra(INETSocketAttribute.socket), lock=self.__send_lock.get)
Expand Down
27 changes: 25 additions & 2 deletions tests/unit_test/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import threading
from collections.abc import Sequence
from socket import AF_INET, AF_INET6, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM
from types import TracebackType
Expand All @@ -10,31 +11,42 @@


class _LockMixin:
_owner: object = None
_locked_count: int = 0
_reentrant: bool = False

class WouldBlock(Exception):
pass

def _get_requester_id(self) -> object:
raise NotImplementedError

def _acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if self._locked_count > 0:
if self._locked_count > 0 and (not self._reentrant or self._owner != self._get_requester_id()):
if not blocking or timeout >= 0:
return False
raise _LockMixin.WouldBlock(f"{self.__class__.__name__}.acquire() would block")

if self._reentrant and self._owner is None:
self._owner = self._get_requester_id()
self._locked_count += 1
return True

def locked(self) -> bool:
return self._locked_count > 0

def _release(self) -> None:
if self._reentrant and self._owner != self._get_requester_id():
raise RuntimeError("release() called on an unacquired lock")
assert self._locked_count > 0
self._locked_count -= 1
if self._reentrant and self._locked_count == 0:
self._owner = None


class DummyLock(_LockMixin):
"""
Helper class used to mock threading.Lock and threading.RLock classes
Helper class used to mock threading.Lock class.
"""

def __enter__(self) -> bool:
Expand All @@ -50,6 +62,17 @@ def release(self) -> None:
return self._release()


class DummyRLock(DummyLock):
"""
Helper class used to mock threading.RLock class.
"""

_reentrant = True

def _get_requester_id(self) -> object:
return threading.get_ident()


class AsyncDummyLock(_LockMixin):
"""
Helper class used to mock asyncio.Lock classes
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_test/test_sync/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

@pytest.fixture(autouse=True)
def dummy_lock_cls(mocker: MockerFixture) -> Any:
from .._utils import DummyLock
from .._utils import DummyLock, DummyRLock

mocker.patch("threading.Lock", new=DummyLock)
mocker.patch("threading.RLock", new=DummyLock)
mocker.patch("threading.RLock", new=DummyRLock)
return DummyLock

0 comments on commit 0cb93e7

Please sign in to comment.