Skip to content

Commit

Permalink
Internal: Mutualized timeout code (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Nov 16, 2023
1 parent c59949a commit 8f50e10
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 53 deletions.
10 changes: 3 additions & 7 deletions src/easynetwork/api_async/client/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
__all__ = ["AbstractAsyncNetworkClient"]

import math
import time
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Generic, Self

from ..._typevars import _ReceivedPacketT, _SentPacketT
from ...lowlevel import _utils
from ...lowlevel.socket import SocketAddress

if TYPE_CHECKING:
Expand Down Expand Up @@ -220,20 +220,16 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter
if timeout is None:
timeout = math.inf

perf_counter = time.perf_counter
timeout_after = self.get_backend().timeout

while True:
try:
with timeout_after(timeout):
_start = perf_counter()
with timeout_after(timeout), _utils.ElapsedTime() as elapsed:
packet = await self.recv_packet()
_end = perf_counter()
except OSError:
return
yield packet
timeout -= _end - _start
timeout = max(timeout, 0)
timeout = elapsed.recompute_timeout(timeout)

@abstractmethod
def get_backend(self) -> AsyncBackend:
Expand Down
12 changes: 4 additions & 8 deletions src/easynetwork/api_sync/client/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

__all__ = ["AbstractNetworkClient"]

import time
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING, Generic, Self

from ..._typevars import _ReceivedPacketT, _SentPacketT
from ...lowlevel import _utils
from ...lowlevel.socket import SocketAddress

if TYPE_CHECKING:
Expand Down Expand Up @@ -174,19 +174,15 @@ def iter_received_packets(self, *, timeout: float | None = 0) -> Iterator[_Recei
Yields:
the received packet.
"""
perf_counter = time.perf_counter

while True:
try:
_start = perf_counter()
packet = self.recv_packet(timeout=timeout)
_end = perf_counter()
with _utils.ElapsedTime() as elapsed:
packet = self.recv_packet(timeout=timeout)
except OSError:
return
yield packet
if timeout is not None:
timeout -= _end - _start
timeout = max(timeout, 0)
timeout = elapsed.recompute_timeout(timeout)

@abstractmethod
def fileno(self) -> int:
Expand Down
8 changes: 4 additions & 4 deletions src/easynetwork/api_sync/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import concurrent.futures
import contextlib
import threading as _threading
import time
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, NoReturn

Expand Down Expand Up @@ -84,11 +83,12 @@ def shutdown(self, timeout: float | None = None) -> None:
if timeout is None:
portal.run_coroutine(self.__server.shutdown)
else:
_start = time.perf_counter()
elapsed = _utils.ElapsedTime()
try:
portal.run_coroutine(self.__do_shutdown_with_timeout, timeout)
with elapsed:
portal.run_coroutine(self.__do_shutdown_with_timeout, timeout)
finally:
timeout -= time.perf_counter() - _start
timeout = elapsed.recompute_timeout(timeout)
self.__is_shutdown.wait(timeout)

async def __do_shutdown_with_timeout(self, timeout_delay: float) -> None:
Expand Down
9 changes: 4 additions & 5 deletions src/easynetwork/api_sync/server/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
]

import threading as _threading
import time

from ...lowlevel import _utils
from .abc import AbstractNetworkServer


Expand Down Expand Up @@ -50,9 +50,8 @@ def run(self) -> None:
self.__is_up_event.set()

def join(self, timeout: float | None = None) -> None:
_start = time.perf_counter()
self.__server.shutdown(timeout=timeout)
_end = time.perf_counter()
with _utils.ElapsedTime() as elapsed:
self.__server.shutdown(timeout=timeout)
if timeout is not None:
timeout -= _end - _start
timeout = elapsed.recompute_timeout(timeout)
super().join(timeout=timeout)
51 changes: 43 additions & 8 deletions src/easynetwork/lowlevel/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

__all__ = [
"ElapsedTime",
"check_real_socket_state",
"check_socket_family",
"check_socket_no_ssl",
Expand All @@ -41,7 +42,7 @@
import threading
import time
from collections.abc import Callable, Iterable, Iterator
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeGuard, TypeVar
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeGuard, TypeVar

try:
import ssl as _ssl
Expand Down Expand Up @@ -228,6 +229,43 @@ def remove_traceback_frames_in_place(exc: _ExcType, n: int) -> _ExcType:
return exc.with_traceback(tb)


class ElapsedTime:
__slots__ = ("_current_time_func", "_start_time", "_end_time")

def __init__(self) -> None:
self._current_time_func: Callable[[], float] = time.perf_counter
self._start_time: float | None = None
self._end_time: float | None = None

def __enter__(self) -> Self:
if self._start_time is not None:
raise RuntimeError("Already entered")
self._start_time = self._current_time_func()
return self

def __exit__(self, *args: Any) -> None:
end_time = self._current_time_func()
if self._end_time is not None:
raise RuntimeError("Already exited")
self._end_time = end_time

def get_elapsed(self) -> float:
start_time = self._start_time
if start_time is None:
raise RuntimeError("Not entered")
end_time = self._end_time
if end_time is None:
raise RuntimeError("Within context")
return end_time - start_time

def recompute_timeout(self, old_timeout: float) -> float:
elapsed_time = self.get_elapsed()
new_timeout = old_timeout - elapsed_time
if new_timeout < 0.0:
new_timeout = 0.0
return new_timeout


@contextlib.contextmanager
def lock_with_timeout(
lock: threading.RLock | threading.Lock,
Expand All @@ -238,19 +276,16 @@ def lock_with_timeout(
yield timeout
return
timeout = validate_timeout_delay(timeout, positive_check=True)
perf_counter = time.perf_counter
with contextlib.ExitStack() as stack:
# Try to acquire without blocking first
if lock.acquire(blocking=False):
stack.push(lock)
else:
_start = perf_counter()
if timeout == 0 or not lock.acquire(True, timeout):
raise error_from_errno(_errno.ETIMEDOUT)
with ElapsedTime() as elapsed:
if timeout == 0 or not lock.acquire(True, timeout):
raise error_from_errno(_errno.ETIMEDOUT)
stack.push(lock)
_end = perf_counter()
timeout -= _end - _start
timeout = max(timeout, 0.0)
timeout = elapsed.recompute_timeout(timeout)
yield timeout


Expand Down
10 changes: 3 additions & 7 deletions src/easynetwork/lowlevel/api_sync/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import errno as _errno
import math
import time
from collections.abc import Callable, Mapping
from typing import Any, Generic, TypeGuard

Expand Down Expand Up @@ -192,12 +191,10 @@ def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT:
raise EOFError("end-of-stream")

bufsize: int = self.__max_recv_size
perf_counter = time.perf_counter # pull function to local namespace

while True:
_start = perf_counter()
chunk: bytes = transport.recv(bufsize, timeout)
_end = perf_counter()
with _utils.ElapsedTime() as elapsed:
chunk: bytes = transport.recv(bufsize, timeout)
if not chunk:
self.__eof_reached = True
raise EOFError("end-of-stream")
Expand All @@ -211,8 +208,7 @@ def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT:
return next(consumer)
except StopIteration:
if timeout > 0:
timeout -= _end - _start
timeout = max(timeout, 0.0)
timeout = elapsed.recompute_timeout(timeout)
elif buffer_not_full:
break
# Loop break
Expand Down
11 changes: 3 additions & 8 deletions src/easynetwork/lowlevel/api_sync/transports/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
"StreamWriteTransport",
]

import time
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable

from ... import typed_attr
from ... import _utils, typed_attr


class BaseTransport(typed_attr.TypedAttributeProvider, metaclass=ABCMeta):
Expand Down Expand Up @@ -129,7 +128,6 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None
TimeoutError: Operation timed out.
"""

perf_counter = time.perf_counter # pull function to local namespace
total_sent: int = 0
with memoryview(data) as data:
nb_bytes_to_send = len(data)
Expand All @@ -139,15 +137,12 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None
raise RuntimeError("transport.send() returned a negative value")
return
while total_sent < nb_bytes_to_send:
with data[total_sent:] as buffer:
_start = perf_counter()
with data[total_sent:] as buffer, _utils.ElapsedTime() as elapsed:
sent = self.send(buffer, timeout)
_end = perf_counter()
if sent < 0:
raise RuntimeError("transport.send() returned a negative value")
total_sent += sent
timeout -= _end - _start
timeout = max(timeout, 0.0)
timeout = elapsed.recompute_timeout(timeout)

def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None:
"""
Expand Down
9 changes: 3 additions & 6 deletions src/easynetwork/lowlevel/api_sync/transports/base_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import errno as _errno
import math
import selectors
import time
from abc import abstractmethod
from collections.abc import Callable
from typing import TypeVar
Expand Down Expand Up @@ -99,7 +98,6 @@ def _retry(
callback: Callable[[], _R],
timeout: float,
) -> _R:
perf_counter = time.perf_counter # pull function to local namespace
timeout = _utils.validate_timeout_delay(timeout, positive_check=True)
retry_interval = self._retry_interval
event: int
Expand Down Expand Up @@ -133,10 +131,9 @@ def _retry(
if not available:
raise RuntimeError("timeout error with infinite timeout")
else:
_start = perf_counter()
available = bool(selector.select(wait_time))
_end = perf_counter()
timeout -= _end - _start
with _utils.ElapsedTime() as elapsed:
available = bool(selector.select(wait_time))
timeout = elapsed.recompute_timeout(timeout)
if not available:
if not is_retry_interval:
break
Expand Down
53 changes: 53 additions & 0 deletions tests/unit_test/test_tools/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from easynetwork.exceptions import BusyResourceError
from easynetwork.lowlevel._utils import (
ElapsedTime,
ResourceGuard,
check_real_socket_state,
check_socket_family,
Expand Down Expand Up @@ -557,6 +558,58 @@ def func() -> None:
assert len(list(traceback.walk_tb(exception.__traceback__))) == 0


def test____ElapsedTime____catch_elapsed_time(mocker: MockerFixture) -> None:
# Arrange
now: float = 798546132.0
mocker.patch("time.perf_counter", autospec=True, side_effect=[now, now + 12.0])

# Act
with ElapsedTime() as elapsed:
pass

# Assert
assert elapsed.get_elapsed() == pytest.approx(12.0)
assert elapsed.recompute_timeout(42.4) == pytest.approx(30.4)
assert elapsed.recompute_timeout(8.0) == 0.0


def test____ElapsedTime____not_reentrant() -> None:
# Arrange
with ElapsedTime() as elapsed:
# Act & Assert
with pytest.raises(RuntimeError, match=r"^Already entered$"):
with elapsed:
pytest.fail("Should not enter")


def test____ElapsedTime____double_exit() -> None:
# Arrange

# Act & Assert
with pytest.raises(RuntimeError, match=r"^Already exited$"):
with contextlib.ExitStack() as stack:
elapsed = stack.enter_context(ElapsedTime())
stack.push(elapsed)


def test____ElapsedTime____get_elapsed____not_entered() -> None:
# Arrange
elapsed = ElapsedTime()

# Act & Assert
with pytest.raises(RuntimeError, match=r"^Not entered$"):
elapsed.get_elapsed()


def test____ElapsedTime____get_elapsed____within_context() -> None:
# Arrange

# Act & Assert
with ElapsedTime() as elapsed:
with pytest.raises(RuntimeError, match=r"^Within context$"):
elapsed.get_elapsed()


def test____lock_with_timeout____acquire_and_release_with_timeout_at_None() -> None:
# Arrange
lock = threading.Lock()
Expand Down

0 comments on commit 8f50e10

Please sign in to comment.