Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronous clients' threading.Lock were replaced by threading.RLock #228

Merged
merged 1 commit into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/easynetwork/api_sync/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
import errno as _errno
import socket as _socket
import threading
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, final, overload

Expand Down Expand Up @@ -240,10 +241,10 @@ def __init__(
assert ssl_shared_lock is not None # nosec assert_used

if ssl and ssl_shared_lock:
self.__send_lock = self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = self.__receive_lock = _lock.ForkSafeLock(threading.Lock)
else:
self.__send_lock = _lock.ForkSafeLock()
self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = _lock.ForkSafeLock(threading.Lock)
self.__receive_lock = _lock.ForkSafeLock(threading.Lock)

try:
self.__endpoint = StreamEndpoint(transport, protocol, max_recv_size=max_recv_size)
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/api_sync/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import contextlib
import socket as _socket
import threading
from collections.abc import Iterator
from typing import Any, final, overload

Expand Down Expand Up @@ -125,8 +126,8 @@ def __init__(
socket.close()
raise

self.__send_lock = _lock.ForkSafeLock()
self.__receive_lock = _lock.ForkSafeLock()
self.__send_lock = _lock.ForkSafeLock(threading.Lock)
self.__receive_lock = _lock.ForkSafeLock(threading.Lock)
try:
self.__endpoint: DatagramEndpoint[_T_SentPacket, _T_ReceivedPacket] = DatagramEndpoint(transport, protocol)
self.__socket_proxy = SocketProxy(transport.extra(INETSocketAttribute.socket), lock=self.__send_lock.get)
Expand Down
27 changes: 25 additions & 2 deletions tests/unit_test/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import threading
from collections.abc import Sequence
from socket import AF_INET, AF_INET6, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM
from types import TracebackType
Expand All @@ -10,31 +11,42 @@


class _LockMixin:
_owner: object = None
_locked_count: int = 0
_reentrant: bool = False

class WouldBlock(Exception):
pass

def _get_requester_id(self) -> object:
raise NotImplementedError

def _acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if self._locked_count > 0:
if self._locked_count > 0 and (not self._reentrant or self._owner != self._get_requester_id()):
if not blocking or timeout >= 0:
return False
raise _LockMixin.WouldBlock(f"{self.__class__.__name__}.acquire() would block")

if self._reentrant and self._owner is None:
self._owner = self._get_requester_id()
self._locked_count += 1
return True

def locked(self) -> bool:
return self._locked_count > 0

def _release(self) -> None:
if self._reentrant and self._owner != self._get_requester_id():
raise RuntimeError("release() called on an unacquired lock")
assert self._locked_count > 0
self._locked_count -= 1
if self._reentrant and self._locked_count == 0:
self._owner = None


class DummyLock(_LockMixin):
"""
Helper class used to mock threading.Lock and threading.RLock classes
Helper class used to mock threading.Lock class.
"""

def __enter__(self) -> bool:
Expand All @@ -50,6 +62,17 @@ def release(self) -> None:
return self._release()


class DummyRLock(DummyLock):
"""
Helper class used to mock threading.RLock class.
"""

_reentrant = True

def _get_requester_id(self) -> object:
return threading.get_ident()


class AsyncDummyLock(_LockMixin):
"""
Helper class used to mock asyncio.Lock classes
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_test/test_sync/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

@pytest.fixture(autouse=True)
def dummy_lock_cls(mocker: MockerFixture) -> Any:
from .._utils import DummyLock
from .._utils import DummyLock, DummyRLock

mocker.patch("threading.Lock", new=DummyLock)
mocker.patch("threading.RLock", new=DummyLock)
mocker.patch("threading.RLock", new=DummyRLock)
return DummyLock