diff --git a/docs/source/api/async/backend.rst b/docs/source/api/async/backend.rst index 84943d1d..24711597 100644 --- a/docs/source/api/async/backend.rst +++ b/docs/source/api/async/backend.rst @@ -72,17 +72,11 @@ Creating Concurrent Tasks .. autoclass:: Task :members: -Spawning System Tasks -""""""""""""""""""""" - -.. automethod:: AsyncBackend.spawn_task - -.. autoclass:: SystemTask - :members: - Timeouts ^^^^^^^^ +.. automethod:: AsyncBackend.open_cancel_scope + .. automethod:: AsyncBackend.move_on_after .. automethod:: AsyncBackend.move_on_at @@ -91,7 +85,7 @@ Timeouts .. automethod:: AsyncBackend.timeout_at -.. autoclass:: TimeoutHandle +.. autoclass:: CancelScope :members: Networking @@ -205,16 +199,6 @@ Backend Factory :exclude-members: GROUP_NAME -Tasks Utilities -=============== - -.. automodule:: easynetwork.api_async.backend.tasks - :no-docstring: - -.. autoclass:: SingleTaskRunner - :members: - - Concurrency And Multithreading (``concurrent.futures`` Integration) =================================================================== diff --git a/src/easynetwork/api_async/backend/abc.py b/src/easynetwork/api_async/backend/abc.py index 724214db..53f77322 100644 --- a/src/easynetwork/api_async/backend/abc.py +++ b/src/easynetwork/api_async/backend/abc.py @@ -23,20 +23,21 @@ "AsyncDatagramSocketAdapter", "AsyncListenerSocketAdapter", "AsyncStreamSocketAdapter", + "CancelScope", "ICondition", "IEvent", "ILock", "Task", "TaskGroup", "ThreadsPortal", - "TimeoutHandle", ] +import contextlib import contextvars import math from abc import ABCMeta, abstractmethod -from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping, Sequence -from contextlib import AbstractAsyncContextManager +from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Mapping, Sequence +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Any, Generic, NoReturn, ParamSpec, Protocol, Self, TypeVar if TYPE_CHECKING: @@ -196,12 +197,12 @@ class Task(Generic[_T_co], metaclass=ABCMeta): @abstractmethod def done(self) -> bool: """ - Returns :data:`True` if the Task is done. + Returns the Task state. A Task is *done* when the wrapped coroutine either returned a value, raised an exception, or the Task was cancelled. Returns: - The Task state. + :data:`True` if the Task is done. """ raise NotImplementedError @@ -225,13 +226,13 @@ def cancel(self) -> bool: @abstractmethod def cancelled(self) -> bool: """ - Returns :data:`True` if the Task is *cancelled*. + Returns the cancellation state. The Task is *cancelled* when the cancellation was requested with :meth:`cancel` and the wrapped coroutine propagated the ``backend.get_cancelled_exc_class()`` exception thrown into it. Returns: - the cancellation state. + :data:`True` if the Task is *cancelled* """ raise NotImplementedError @@ -264,14 +265,6 @@ async def join(self) -> _T_co: """ raise NotImplementedError - -class SystemTask(Task[_T_co]): - """ - A :class:`SystemTask` is a :class:`Task` that runs concurrently with the current root task. - """ - - __slots__ = () - @abstractmethod async def join_or_cancel(self) -> _T_co: """ @@ -292,6 +285,104 @@ async def join_or_cancel(self) -> _T_co: raise NotImplementedError +class CancelScope(metaclass=ABCMeta): + """ + A temporary scope opened by a task that can be used by other tasks to control its execution time. + + Unlike trio's CancelScope, there is no "shielded" scopes; you must use :meth:`AsyncBackend.ignore_cancellation`. + """ + + __slots__ = ("__weakref__",) + + @abstractmethod + def __enter__(self) -> Self: + raise NotImplementedError + + @abstractmethod + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool: + raise NotImplementedError + + @abstractmethod + def cancel(self) -> None: + """ + Request the Task to be cancelled. + + This arranges for a ``backend.get_cancelled_exc_class()`` exception to be thrown into the wrapped coroutine + on the next cycle of the event loop. + + :meth:`CancelScope.cancel` does not guarantee that the Task will be cancelled, + although suppressing cancellation completely is not common and is actively discouraged. + """ + raise NotImplementedError + + @abstractmethod + def cancel_called(self) -> bool: + """ + Checks if :meth:`cancel` has been called. + + Returns: + :data:`True` if :meth:`cancel` has been called. + """ + raise NotImplementedError + + @abstractmethod + def cancelled_caught(self) -> bool: + """ + Returns the scope cancellation state. + + Returns: + :data:`True` if the scope has been is *cancelled*. + """ + raise NotImplementedError + + @abstractmethod + def when(self) -> float: + """ + Returns the current deadline. + + Returns: + the absolute time in seconds. :data:`math.inf` if the current deadline is not set. + """ + raise NotImplementedError + + @abstractmethod + def reschedule(self, when: float, /) -> None: + """ + Reschedules the timeout. + + Parameters: + when: The new deadline. + """ + raise NotImplementedError + + @property + def deadline(self) -> float: + """ + A read-write attribute to simplify the timeout management. + + For example, this statement:: + + scope.deadline += 30 + + is equivalent to:: + + scope.reschedule(scope.when() + 30) + + It is also possible to remove the timeout by deleting the attribute:: + + del scope.deadline + """ + return self.when() + + @deadline.setter + def deadline(self, value: float) -> None: + self.reschedule(value) + + @deadline.deleter + def deadline(self) -> None: + self.reschedule(math.inf) + + class TaskGroup(metaclass=ABCMeta): """ Groups several asynchronous tasks together. @@ -664,74 +755,6 @@ async def connect(self) -> AsyncStreamSocketAdapter: raise NotImplementedError -class TimeoutHandle(metaclass=ABCMeta): - """ - Interface to deal with an actual timeout scope. - - See :meth:`AsyncBackend.move_on_after` for details. - """ - - __slots__ = () - - @abstractmethod - def when(self) -> float: - """ - Returns the current deadline. - - Returns: - the absolute time in seconds. :data:`math.inf` if the current deadline is not set. - A negative value can be returned. - """ - raise NotImplementedError - - @abstractmethod - def reschedule(self, when: float, /) -> None: - """ - Reschedules the timeout. - - Parameters: - when: The new deadline. - """ - raise NotImplementedError - - @abstractmethod - def expired(self) -> bool: - """ - Returns whether the context manager has exceeded its deadline (expired). - - Returns: - the timeout state. - """ - raise NotImplementedError - - @property - def deadline(self) -> float: - """ - A read-write attribute to simplify the timeout management. - - For example, this statement:: - - handle.deadline += 30 - - is equivalent to:: - - handle.reschedule(handle.when() + 30) - - It is also possible to remove the timeout by deleting the attribute:: - - del handle.deadline - """ - return self.when() - - @deadline.setter - def deadline(self, value: float) -> None: - self.reschedule(value) - - @deadline.deleter - def deadline(self) -> None: - self.reschedule(math.inf) - - class AsyncBackend(metaclass=ABCMeta): """ Asynchronous backend interface. @@ -837,9 +860,21 @@ async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, _T_co]) -> _T raise NotImplementedError @abstractmethod - def timeout(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def open_cancel_scope(self, *, deadline: float = ...) -> CancelScope: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Open a new cancel scope. See :meth:`move_on_after` for details. + + Parameters: + deadline: absolute time to stop waiting. Defaults to :data:`math.inf`. + + Returns: + a new cancel scope. + """ + raise NotImplementedError + + def timeout(self, delay: float) -> AbstractContextManager[CancelScope]: + """ + Returns a :term:`context manager` that can be used to limit the amount of time spent waiting on something. This function and :meth:`move_on_after` are similar in that both create a context manager with a given timeout, and if the timeout expires then both will cause ``backend.get_cancelled_exc_class()`` to be raised within the scope. @@ -850,14 +885,13 @@ def timeout(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: delay: number of seconds to wait. Returns: - an :term:`asynchronous context manager` + a :term:`context manager` """ - raise NotImplementedError + return _timeout_after(self, delay) - @abstractmethod - def timeout_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def timeout_at(self, deadline: float) -> AbstractContextManager[CancelScope]: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Returns a :term:`context manager` that can be used to limit the amount of time spent waiting on something. This function and :meth:`move_on_at` are similar in that both create a context manager with a given timeout, and if the timeout expires then both will cause ``backend.get_cancelled_exc_class()`` to be raised within the scope. @@ -868,14 +902,13 @@ def timeout_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHand deadline: absolute time to stop waiting. Returns: - an :term:`asynchronous context manager` + a :term:`context manager` """ - raise NotImplementedError + return _timeout_at(self, deadline) - @abstractmethod - def move_on_after(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def move_on_after(self, delay: float) -> CancelScope: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Returns a new :class:`CancelScope` that can be used to limit the amount of time spent waiting on something. The deadline is set to now + `delay`. Example:: @@ -886,7 +919,7 @@ async def long_running_operation(backend): async def main(): ... - async with backend.move_on_after(10): + with backend.move_on_after(10): await long_running_operation(backend) print("After at most 10 seconds.") @@ -897,15 +930,14 @@ async def main(): Parameters: delay: number of seconds to wait. If `delay` is :data:`math.inf`, no time limit will be applied; this can be useful if the delay is unknown when the context manager is created. - In either case, the context manager can be rescheduled after creation using :meth:`TimeoutHandle.reschedule`. + In either case, the context manager can be rescheduled after creation using :meth:`CancelScope.reschedule`. Returns: - an :term:`asynchronous context manager` + a new cancel scope. """ - raise NotImplementedError + return self.open_cancel_scope(deadline=self.current_time() + delay) - @abstractmethod - def move_on_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def move_on_at(self, deadline: float) -> CancelScope: """ Similar to :meth:`move_on_after`, except `deadline` is the absolute time to stop waiting, or :data:`math.inf`. @@ -918,7 +950,7 @@ async def main(): ... deadline = backend.current_time() + 10 - async with backend.move_on_at(deadline): + with backend.move_on_at(deadline): await long_running_operation(backend) print("After at most 10 seconds.") @@ -927,9 +959,9 @@ async def main(): deadline: absolute time to stop waiting. Returns: - an :term:`asynchronous context manager` + a new cancel scope. """ - raise NotImplementedError + return self.open_cancel_scope(deadline=deadline) @abstractmethod def current_time(self) -> float: @@ -978,31 +1010,6 @@ async def sleep_until(self, deadline: float) -> None: """ return await self.sleep(max(deadline - self.current_time(), 0)) - @abstractmethod - def spawn_task( - self, - coro_func: Callable[..., Coroutine[Any, Any, _T]], - /, - *args: Any, - context: contextvars.Context | None = ..., - ) -> SystemTask[_T]: - """ - Starts a new "system" task. - - It is a background task that runs concurrently with the current root task. - - Parameters: - coro_func: An async function. - args: Positional arguments to be passed to `coro_func`. If you need to pass keyword arguments, - then use :func:`functools.partial`. - context: If given, it must be a :class:`contextvars.Context` instance in which the coroutine should be executed. - If the framework does not support contexts (or does not use them), it must simply ignore this parameter. - - Returns: - the created task. - """ - raise NotImplementedError - @abstractmethod def create_task_group(self) -> TaskGroup: """ @@ -1373,3 +1380,16 @@ async def wait_future(self, future: concurrent.futures.Future[_T_co]) -> _T_co: Whatever returns ``future.result()`` """ raise NotImplementedError + + +def _timeout_after(backend: AsyncBackend, delay: float) -> contextlib._GeneratorContextManager[CancelScope]: + return _timeout_at(backend, backend.current_time() + delay) + + +@contextlib.contextmanager +def _timeout_at(backend: AsyncBackend, deadline: float) -> Iterator[CancelScope]: + with backend.move_on_at(deadline) as scope: + yield scope + + if scope.cancelled_caught(): + raise TimeoutError("timed out") diff --git a/src/easynetwork/api_async/backend/tasks.py b/src/easynetwork/api_async/backend/tasks.py deleted file mode 100644 index 4cb513a8..00000000 --- a/src/easynetwork/api_async/backend/tasks.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -"""Task utilities module""" - -from __future__ import annotations - -__all__ = ["SingleTaskRunner"] - -import functools -from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar - -if TYPE_CHECKING: - from .abc import AsyncBackend, SystemTask - - -_P = ParamSpec("_P") -_T_co = TypeVar("_T_co", covariant=True) - - -class SingleTaskRunner(Generic[_T_co]): - """ - An helper class to execute a coroutine function only once. - - In addition to one-time execution, concurrent calls will simply wait for the result:: - - async def expensive_task(): - print("Start expensive task") - - ... - - print("Done") - return 42 - - async def main(): - ... - - task_runner = SingleTaskRunner(backend, expensive_task) - async with backend.create_task_group() as task_group: - tasks = [task_group.start_soon(task_runner.run) for _ in range(10)] - - assert all(await t.join() == 42 for t in tasks) - """ - - __slots__ = ( - "__backend", - "__coro_func", - "__task", - "__weakref__", - ) - - def __init__( - self, - backend: AsyncBackend, - coro_func: Callable[_P, Coroutine[Any, Any, _T_co]], - /, - *args: _P.args, - **kwargs: _P.kwargs, - ) -> None: - """ - Parameters: - backend: The asynchronous backend interface. - coro_func: An async function. - args: Positional arguments to be passed to `coro_func`. - kwargs: Keyword arguments to be passed to `coro_func`. - """ - super().__init__() - - self.__backend: AsyncBackend = backend - self.__coro_func: Callable[[], Coroutine[Any, Any, _T_co]] | None = functools.partial( - coro_func, - *args, - **kwargs, - ) - self.__task: SystemTask[_T_co] | None = None - - def cancel(self) -> bool: - """ - Cancel coroutine execution. - - If the runner was not used yet, :meth:`run` will not call `coro_func` and raise ``backend.get_cancelled_exc_class()``. - - If `coro_func` is already running, a cancellation request is sent to the coroutine. - - Returns: - :data:`True` in case of success, :data:`False` otherwise. - """ - self.__coro_func = None - if self.__task is not None: - return self.__task.cancel() - return True - - async def run(self) -> _T_co: - """ - Executes the coroutine `coro_func`. - - Raises: - Exception: Whatever ``coro_func`` raises. - - Returns: - Whatever ``coro_func`` returns. - """ - must_cancel_inner_task: bool = False - if self.__task is None: - must_cancel_inner_task = True - if self.__coro_func is None: - self.__task = self.__backend.spawn_task(self.__backend.sleep_forever) - self.__task.cancel() - else: - coro_func = self.__coro_func - self.__coro_func = None - self.__task = self.__backend.spawn_task(coro_func) - del coro_func - - try: - if must_cancel_inner_task: - return await self.__task.join_or_cancel() - else: - return await self.__task.join() - finally: - del self # Avoid circular reference with raised exception (if any) diff --git a/src/easynetwork/api_async/client/abc.py b/src/easynetwork/api_async/client/abc.py index 2a7b79f6..7b291f58 100644 --- a/src/easynetwork/api_async/client/abc.py +++ b/src/easynetwork/api_async/client/abc.py @@ -228,7 +228,7 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter while True: try: - async with timeout_after(timeout): + with timeout_after(timeout): _start = perf_counter() packet = await self.recv_packet() _end = perf_counter() diff --git a/src/easynetwork/api_async/client/tcp.py b/src/easynetwork/api_async/client/tcp.py index 6f192c13..e7a019c0 100644 --- a/src/easynetwork/api_async/client/tcp.py +++ b/src/easynetwork/api_async/client/tcp.py @@ -19,10 +19,11 @@ __all__ = ["AsyncTCPNetworkClient"] import contextlib as _contextlib +import dataclasses as _dataclasses import errno as _errno import socket as _socket -from collections.abc import Callable, Iterator, Mapping -from typing import TYPE_CHECKING, Any, NoReturn, TypedDict, final, overload +from collections.abc import Awaitable, Callable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, NoReturn, final, overload try: import ssl as _ssl @@ -41,24 +42,41 @@ check_socket_family as _check_socket_family, check_socket_no_ssl as _check_socket_no_ssl, error_from_errno as _error_from_errno, + make_callback as _make_callback, ) from ...tools.constants import CLOSED_SOCKET_ERRNOS, MAX_STREAM_BUFSIZE, SSL_HANDSHAKE_TIMEOUT, SSL_SHUTDOWN_TIMEOUT from ...tools.socket import SocketAddress, SocketProxy, new_socket_address, set_tcp_keepalive, set_tcp_nodelay -from ..backend.abc import AsyncBackend, AsyncStreamSocketAdapter, ILock +from ..backend.abc import AsyncBackend, AsyncStreamSocketAdapter, CancelScope, ILock from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from .abc import AbstractAsyncNetworkClient if TYPE_CHECKING: import ssl as _typing_ssl -class _ClientInfo(TypedDict): +@_dataclasses.dataclass(kw_only=True, frozen=True, slots=True) +class _ClientInfo: proxy: SocketProxy local_address: SocketAddress remote_address: SocketAddress +@_dataclasses.dataclass(kw_only=True, slots=True) +class _SocketConnector: + lock: ILock + factory: Callable[[], Awaitable[tuple[AsyncStreamSocketAdapter, _ClientInfo]]] | None + scope: CancelScope + _result: tuple[AsyncStreamSocketAdapter, _ClientInfo] | None = _dataclasses.field(init=False, default=None) + + async def get(self) -> tuple[AsyncStreamSocketAdapter, _ClientInfo] | None: + async with self.lock: + factory, self.factory = self.factory, None + if factory is not None: + with self.scope: + self._result = await factory() + return self._result + + class AsyncTCPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPacketT]): """ An asynchronous network client interface for TCP connections. @@ -217,7 +235,7 @@ def __init__( def _value_or_default(value: float | None, default: float) -> float: return value if value is not None else default - self.__socket_connector: SingleTaskRunner[AsyncStreamSocketAdapter] | None = None + socket_factory: Callable[[], Awaitable[AsyncStreamSocketAdapter]] match __arg: case _socket.socket() as socket: _check_socket_family(socket.family) @@ -225,8 +243,7 @@ def _value_or_default(value: float | None, default: float) -> float: if ssl: if server_hostname is None: raise ValueError("You must set server_hostname when using ssl without a host") - self.__socket_connector = SingleTaskRunner( - backend, + socket_factory = _make_callback( backend.wrap_ssl_over_tcp_client_socket, socket, ssl_context=ssl, @@ -236,11 +253,10 @@ def _value_or_default(value: float | None, default: float) -> float: **kwargs, ) else: - self.__socket_connector = SingleTaskRunner(backend, backend.wrap_tcp_client_socket, socket, **kwargs) + socket_factory = _make_callback(backend.wrap_tcp_client_socket, socket, **kwargs) case (host, port): if ssl: - self.__socket_connector = SingleTaskRunner( - backend, + socket_factory = _make_callback( backend.create_ssl_over_tcp_connection, host, port, @@ -251,11 +267,16 @@ def _value_or_default(value: float | None, default: float) -> float: **kwargs, ) else: - self.__socket_connector = SingleTaskRunner(backend, backend.create_tcp_connection, host, port, **kwargs) + socket_factory = _make_callback(backend.create_tcp_connection, host, port, **kwargs) case _: # pragma: no cover raise TypeError("Invalid arguments") - assert self.__socket_connector is not None # nosec assert_used + self.__socket_connector: _SocketConnector | None = _SocketConnector( + lock=self.__backend.create_lock(), + factory=_make_callback(self.__create_socket, socket_factory), + scope=self.__backend.open_cancel_scope(), + ) + assert ssl_shared_lock is not None # nosec assert_used self.__receive_lock: ILock @@ -312,33 +333,26 @@ async def wait_connected(self) -> None: ConnectionError: could not connect to remote. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - if self.__socket is None: - socket_connector = self.__socket_connector - if socket_connector is None: - raise ClientClosedError("Client is closing, or is already closed") - socket = await socket_connector.run() - if self.__socket_connector is None: # wait_connected() or aclose() called in concurrency - return await self.__backend.cancel_shielded_coro_yield() - self.__socket = socket - self.__socket_connector = None - if self.__info is None: - self.__info = self.__build_info_dict(self.__socket) - socket_proxy = self.__info["proxy"] - with _contextlib.suppress(OSError): - set_tcp_nodelay(socket_proxy, True) - with _contextlib.suppress(OSError): - set_tcp_keepalive(socket_proxy, True) + await self.__ensure_connected() @staticmethod - def __build_info_dict(socket: AsyncStreamSocketAdapter) -> _ClientInfo: + async def __create_socket( + socket_factory: Callable[[], Awaitable[AsyncStreamSocketAdapter]], + ) -> tuple[AsyncStreamSocketAdapter, _ClientInfo]: + socket = await socket_factory() socket_proxy = SocketProxy(socket.socket()) local_address: SocketAddress = new_socket_address(socket_proxy.getsockname(), socket_proxy.family) remote_address: SocketAddress = new_socket_address(socket_proxy.getpeername(), socket_proxy.family) - return { - "proxy": socket_proxy, - "local_address": local_address, - "remote_address": remote_address, - } + info = _ClientInfo( + proxy=socket_proxy, + local_address=local_address, + remote_address=remote_address, + ) + with _contextlib.suppress(OSError): + set_tcp_nodelay(socket_proxy, True) + with _contextlib.suppress(OSError): + set_tcp_keepalive(socket_proxy, True) + return socket, info def is_closing(self) -> bool: """ @@ -372,7 +386,7 @@ async def aclose(self) -> None: Can be safely called multiple times. """ if self.__socket_connector is not None: - self.__socket_connector.cancel() + self.__socket_connector.scope.cancel() self.__socket_connector = None async with self.__send_lock: socket, self.__socket = self.__socket, None @@ -389,8 +403,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: """ Sends `packet` to the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Warning: In the case of a cancellation, it is impossible to know if all the packet data has been sent. This would leave the connection in an inconsistent state. @@ -437,8 +449,6 @@ async def recv_packet(self) -> _ReceivedPacketT: """ Waits for a new packet to arrive from the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Raises: ClientClosedError: the client object is closed. ConnectionError: connection unexpectedly closed during operation. @@ -492,7 +502,7 @@ def get_local_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["local_address"] + return self.__info.local_address def get_remote_address(self) -> SocketAddress: """ @@ -509,7 +519,7 @@ def get_remote_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["remote_address"] + return self.__info.remote_address def get_backend(self) -> AsyncBackend: return self.__backend @@ -517,12 +527,18 @@ def get_backend(self) -> AsyncBackend: get_backend.__doc__ = AbstractAsyncNetworkClient.get_backend.__doc__ async def __ensure_connected(self) -> AsyncStreamSocketAdapter: - await self.wait_connected() - assert self.__socket is not None # nosec assert_used - socket = self.__socket - if socket.is_closing(): + if self.__socket is None or self.__info is None: + socket_and_info = None + if (socket_connector := self.__socket_connector) is not None: + socket_and_info = await socket_connector.get() + self.__socket_connector = None + if socket_and_info is None: + raise ClientClosedError("Client is closing, or is already closed") + self.__socket, self.__info = socket_and_info + + if self.__socket.is_closing(): self.__abort(None) - return socket + return self.__socket @_contextlib.contextmanager def __convert_socket_error(self) -> Iterator[None]: @@ -550,7 +566,7 @@ def socket(self) -> SocketProxy: """ if self.__info is None: raise AttributeError("Socket not connected") - return self.__info["proxy"] + return self.__info.proxy @property @final diff --git a/src/easynetwork/api_async/client/udp.py b/src/easynetwork/api_async/client/udp.py index 24cb69db..2c15644b 100644 --- a/src/easynetwork/api_async/client/udp.py +++ b/src/easynetwork/api_async/client/udp.py @@ -19,12 +19,13 @@ __all__ = ["AsyncUDPNetworkClient", "AsyncUDPNetworkEndpoint"] import contextlib +import dataclasses as _dataclasses import errno as _errno import math import socket as _socket import time -from collections.abc import AsyncGenerator, AsyncIterator, Mapping -from typing import TYPE_CHECKING, Any, Generic, Self, TypedDict, final, overload +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping +from typing import TYPE_CHECKING, Any, Generic, Self, final, overload from ..._typevars import _ReceivedPacketT, _SentPacketT from ...exceptions import ClientClosedError, DatagramProtocolParseError @@ -35,24 +36,41 @@ check_socket_no_ssl as _check_socket_no_ssl, ensure_datagram_socket_bound as _ensure_datagram_socket_bound, error_from_errno as _error_from_errno, + make_callback as _make_callback, ) from ...tools.constants import MAX_DATAGRAM_BUFSIZE from ...tools.socket import SocketAddress, SocketProxy, new_socket_address -from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, ILock +from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, CancelScope, ILock from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from .abc import AbstractAsyncNetworkClient if TYPE_CHECKING: from types import TracebackType -class _EndpointInfo(TypedDict): +@_dataclasses.dataclass(kw_only=True, frozen=True, slots=True) +class _EndpointInfo: proxy: SocketProxy local_address: SocketAddress remote_address: SocketAddress | None +@_dataclasses.dataclass(kw_only=True, slots=True) +class _SocketConnector: + lock: ILock + factory: Callable[[], Awaitable[tuple[AsyncDatagramSocketAdapter, _EndpointInfo]]] | None + scope: CancelScope + _result: tuple[AsyncDatagramSocketAdapter, _EndpointInfo] | None = _dataclasses.field(init=False, default=None) + + async def get(self) -> tuple[AsyncDatagramSocketAdapter, _EndpointInfo] | None: + async with self.lock: + factory, self.factory = self.factory, None + if factory is not None: + with self.scope: + self._result = await factory() + return self._result + + class AsyncUDPNetworkEndpoint(Generic[_SentPacketT, _ReceivedPacketT]): """Asynchronous generic UDP endpoint interface. @@ -64,7 +82,7 @@ class AsyncUDPNetworkEndpoint(Generic[_SentPacketT, _ReceivedPacketT]): __slots__ = ( "__socket", "__backend", - "__socket_builder", + "__socket_connector", "__info", "__receive_lock", "__send_lock", @@ -136,18 +154,23 @@ def __init__( self.__backend: AsyncBackend = backend self.__info: _EndpointInfo | None = None - self.__socket_builder: SingleTaskRunner[AsyncDatagramSocketAdapter] | None = None + socket_factory: Callable[[], Awaitable[AsyncDatagramSocketAdapter]] match kwargs: case {"socket": _socket.socket() as socket, **kwargs}: _check_socket_family(socket.family) _check_socket_no_ssl(socket) _ensure_datagram_socket_bound(socket) - self.__socket_builder = SingleTaskRunner(backend, backend.wrap_udp_socket, socket, **kwargs) + socket_factory = _make_callback(backend.wrap_udp_socket, socket, **kwargs) case _: if kwargs.get("local_address") is None: kwargs["local_address"] = ("localhost", 0) - self.__socket_builder = SingleTaskRunner(backend, backend.create_udp_endpoint, **kwargs) + socket_factory = _make_callback(backend.create_udp_endpoint, **kwargs) + self.__socket_connector: _SocketConnector | None = _SocketConnector( + lock=self.__backend.create_lock(), + factory=_make_callback(self.__create_socket, socket_factory), + scope=self.__backend.open_cancel_scope(), + ) self.__receive_lock: ILock = backend.create_lock() self.__send_lock: ILock = backend.create_lock() @@ -211,20 +234,13 @@ async def wait_bound(self) -> None: ClientClosedError: the client object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - if self.__socket is None: - socket_builder = self.__socket_builder - if socket_builder is None: - raise ClientClosedError("Client is closing, or is already closed") - socket = await socket_builder.run() - if self.__socket_builder is None: # wait_bound() or aclose() called in concurrency - return await self.__backend.cancel_shielded_coro_yield() - self.__socket = socket - self.__socket_builder = None - if self.__info is None: - self.__info = self.__build_info_dict(self.__socket) + await self.__ensure_opened() @staticmethod - def __build_info_dict(socket: AsyncDatagramSocketAdapter) -> _EndpointInfo: + async def __create_socket( + socket_factory: Callable[[], Awaitable[AsyncDatagramSocketAdapter]], + ) -> tuple[AsyncDatagramSocketAdapter, _EndpointInfo]: + socket = await socket_factory() socket_proxy = SocketProxy(socket.socket()) local_address: SocketAddress = new_socket_address(socket_proxy.getsockname(), socket_proxy.family) if local_address.port == 0: @@ -234,11 +250,12 @@ def __build_info_dict(socket: AsyncDatagramSocketAdapter) -> _EndpointInfo: remote_address = new_socket_address(socket_proxy.getpeername(), socket_proxy.family) except OSError: remote_address = None - return { - "proxy": socket_proxy, - "local_address": local_address, - "remote_address": remote_address, - } + info = _EndpointInfo( + proxy=socket_proxy, + local_address=local_address, + remote_address=remote_address, + ) + return socket, info def is_closing(self) -> bool: """ @@ -252,7 +269,7 @@ def is_closing(self) -> bool: Returns: the endpoint state. """ - if self.__socket_builder is not None: + if self.__socket_connector is not None: return False socket = self.__socket return socket is None or socket.is_closing() @@ -270,9 +287,9 @@ async def aclose(self) -> None: Can be safely called multiple times. """ - if self.__socket_builder is not None: - self.__socket_builder.cancel() - self.__socket_builder = None + if self.__socket_connector is not None: + self.__socket_connector.scope.cancel() + self.__socket_connector = None async with self.__send_lock: socket, self.__socket = self.__socket, None if socket is None: @@ -292,8 +309,6 @@ async def send_packet_to( """ Sends `packet` to the remote endpoint `address`. Does not require task synchronization. - Calls :meth:`wait_bound`. - If a remote address is configured, `address` must be :data:`None` or the same as the remote address, otherwise `address` must not be :data:`None`. @@ -311,7 +326,7 @@ async def send_packet_to( async with self.__send_lock: socket = await self.__ensure_opened() assert self.__info is not None # nosec assert_used - if (remote_addr := self.__info["remote_address"]) is not None: + if (remote_addr := self.__info.remote_address) is not None: if address is not None: if new_socket_address(address, self.socket.family) != remote_addr: raise ValueError(f"Invalid address: must be None or {remote_addr}") @@ -326,8 +341,6 @@ async def recv_packet_from(self) -> tuple[_ReceivedPacketT, SocketAddress]: """ Waits for a new packet to arrive from another endpoint. Does not require task synchronization. - Calls :meth:`wait_bound`. - Raises: ClientClosedError: the endpoint object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. @@ -386,7 +399,7 @@ async def iter_received_packets_from( while True: try: - async with timeout_after(timeout): + with timeout_after(timeout): _start = perf_counter() packet_tuple = await self.recv_packet_from() _end = perf_counter() @@ -409,7 +422,7 @@ def get_local_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["local_address"] + return self.__info.local_address def get_remote_address(self) -> SocketAddress | None: """ @@ -424,14 +437,20 @@ def get_remote_address(self) -> SocketAddress | None: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["remote_address"] + return self.__info.remote_address def get_backend(self) -> AsyncBackend: return self.__backend async def __ensure_opened(self) -> AsyncDatagramSocketAdapter: - await self.wait_bound() - assert self.__socket is not None # nosec assert_used + if self.__socket is None or self.__info is None: + socket_and_info = None + if (socket_connector := self.__socket_connector) is not None: + socket_and_info = await socket_connector.get() + self.__socket_connector = None + if socket_and_info is None: + raise ClientClosedError("Client is closing, or is already closed") + self.__socket, self.__info = socket_and_info if self.__socket.is_closing(): raise _error_from_errno(_errno.ECONNABORTED) return self.__socket @@ -445,7 +464,7 @@ def socket(self) -> SocketProxy: """ if self.__info is None: raise AttributeError("Socket not connected") - return self.__info["proxy"] + return self.__info.proxy class AsyncUDPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPacketT], Generic[_SentPacketT, _ReceivedPacketT]): @@ -597,8 +616,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: """ Sends `packet` to the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Warning: In the case of a cancellation, it is impossible to know if all the packet data has been sent. @@ -609,7 +626,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: ClientClosedError: the client object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - await self.wait_connected() return await self.__endpoint.send_packet_to(packet, None) async def recv_packet(self) -> _ReceivedPacketT: @@ -626,12 +642,10 @@ async def recv_packet(self) -> _ReceivedPacketT: Returns: the received packet. """ - await self.wait_connected() packet, _ = await self.__endpoint.recv_packet_from() return packet - async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIterator[_ReceivedPacketT]: - await self.wait_connected() + async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncGenerator[_ReceivedPacketT, None]: async with contextlib.aclosing(self.__endpoint.iter_received_packets_from(timeout=timeout)) as generator: async for packet, _ in generator: yield packet diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index 125d08b6..a77acf87 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -54,7 +54,6 @@ set_tcp_nodelay, ) from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from ._tools.actions import ErrorAction as _ErrorAction, RequestAction as _RequestAction from .abc import AbstractAsyncNetworkServer, SupportsEventSet from .handler import AsyncStreamClient, AsyncStreamRequestHandler @@ -67,6 +66,7 @@ AsyncBackend, AsyncListenerSocketAdapter, AsyncStreamSocketAdapter, + CancelScope, IEvent, Task, TaskGroup, @@ -82,7 +82,7 @@ class AsyncTCPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp "__backend", "__listeners", "__listeners_factory", - "__listeners_factory_runner", + "__listeners_factory_scope", "__protocol", "__request_handler", "__is_shutdown", @@ -200,7 +200,7 @@ def _value_or_default(value: float | None, default: float) -> float: backlog=backlog, reuse_port=reuse_port, ) - self.__listeners_factory_runner: SingleTaskRunner[Sequence[AsyncListenerSocketAdapter]] | None = None + self.__listeners_factory_scope: CancelScope | None = None self.__backend: AsyncBackend = backend self.__listeners: tuple[AsyncListenerSocketAdapter, ...] | None = None @@ -239,7 +239,8 @@ def stop_listening(self) -> None: del listener_task async def server_close(self) -> None: - self.__kill_listener_factory_runner() + if self.__listeners_factory_scope is not None: + self.__listeners_factory_scope.cancel() self.__listeners_factory = None await self.__close_listeners() @@ -266,7 +267,6 @@ async def close_listener(listener: AsyncListenerSocketAdapter) -> None: await self.__backend.cancel_shielded_coro_yield() async def shutdown(self) -> None: - self.__kill_listener_factory_runner() if self.__mainloop_task is not None: self.__mainloop_task.cancel() if self.__shutdown_asked: @@ -280,10 +280,6 @@ async def shutdown(self) -> None: shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__ - def __kill_listener_factory_runner(self) -> None: - if self.__listeners_factory_runner is not None: - self.__listeners_factory_runner.cancel() - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: async with _contextlib.AsyncExitStack() as server_exit_stack: is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack()) @@ -300,14 +296,17 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Bind and activate assert self.__listeners is None # nosec assert_used - assert self.__listeners_factory_runner is None # nosec assert_used + assert self.__listeners_factory_scope is None # nosec assert_used if self.__listeners_factory is None: raise ServerClosedError("Closed server") try: - self.__listeners_factory_runner = SingleTaskRunner(self.__backend, self.__listeners_factory) - self.__listeners = tuple(await self.__listeners_factory_runner.run()) + with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: + await self.__backend.coro_yield() + self.__listeners = tuple(await self.__listeners_factory()) + if self.__listeners_factory_scope.cancelled_caught(): + raise ServerClosedError("Closed server") finally: - self.__listeners_factory_runner = None + self.__listeners_factory_scope = None if not self.__listeners: self.__listeners = None raise OSError("empty listeners list") @@ -470,9 +469,8 @@ async def __new_request_handler(self, client: _ConnectedClientAPI[_ResponseT]) - return request_handler_generator async def __force_close_stream_socket(self, socket: AsyncStreamSocketAdapter) -> None: - with _contextlib.suppress(OSError): - async with self.__backend.move_on_after(0): - await socket.aclose() + with self.__backend.move_on_after(0), _contextlib.suppress(OSError): + await socket.aclose() @classmethod def __have_errno(cls, exc: OSError | BaseExceptionGroup[OSError], errnos: set[int]) -> bool: diff --git a/src/easynetwork/api_async/server/udp.py b/src/easynetwork/api_async/server/udp.py index efcda9b4..2fa4803a 100644 --- a/src/easynetwork/api_async/server/udp.py +++ b/src/easynetwork/api_async/server/udp.py @@ -40,13 +40,12 @@ from ...tools.constants import MAX_DATAGRAM_BUFSIZE from ...tools.socket import SocketAddress, SocketProxy, new_socket_address from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from ._tools.actions import ErrorAction as _ErrorAction, RequestAction as _RequestAction from .abc import AbstractAsyncNetworkServer, SupportsEventSet from .handler import AsyncDatagramClient, AsyncDatagramRequestHandler if TYPE_CHECKING: - from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, ICondition, IEvent, ILock, Task, TaskGroup + from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, CancelScope, ICondition, IEvent, ILock, Task, TaskGroup _KT = TypeVar("_KT") _VT = TypeVar("_VT") @@ -61,7 +60,7 @@ class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp "__backend", "__socket", "__socket_factory", - "__socket_factory_runner", + "__socket_factory_scope", "__protocol", "__request_handler", "__is_shutdown", @@ -122,7 +121,7 @@ def __init__( remote_address=None, reuse_port=reuse_port, ) - self.__socket_factory_runner: SingleTaskRunner[AsyncDatagramSocketAdapter] | None = None + self.__socket_factory_scope: CancelScope | None = None self.__backend: AsyncBackend = backend self.__socket: AsyncDatagramSocketAdapter | None = None @@ -149,7 +148,8 @@ def is_serving(self) -> bool: is_serving.__doc__ = AbstractAsyncNetworkServer.is_serving.__doc__ async def server_close(self) -> None: - self.__kill_socket_factory_runner() + if self.__socket_factory_scope is not None: + self.__socket_factory_scope.cancel() self.__socket_factory = None await self.__close_socket() @@ -171,7 +171,6 @@ async def close_socket(socket: AsyncDatagramSocketAdapter) -> None: exit_stack.push_async_callback(self.__mainloop_task.wait) async def shutdown(self) -> None: - self.__kill_socket_factory_runner() if self.__mainloop_task is not None: self.__mainloop_task.cancel() if self.__shutdown_asked: @@ -185,10 +184,6 @@ async def shutdown(self) -> None: shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__ - def __kill_socket_factory_runner(self) -> None: - if self.__socket_factory_runner is not None: - self.__socket_factory_runner.cancel() - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: async with _contextlib.AsyncExitStack() as server_exit_stack: is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack()) @@ -205,14 +200,17 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Bind and activate assert self.__socket is None # nosec assert_used - assert self.__socket_factory_runner is None # nosec assert_used + assert self.__socket_factory_scope is None # nosec assert_used if self.__socket_factory is None: raise ServerClosedError("Closed server") try: - self.__socket_factory_runner = SingleTaskRunner(self.__backend, self.__socket_factory) - self.__socket = await self.__socket_factory_runner.run() + with self.__backend.open_cancel_scope() as self.__socket_factory_scope: + await self.__backend.coro_yield() + self.__socket = await self.__socket_factory() + if self.__socket_factory_scope.cancelled_caught() or self.__socket is None: + raise ServerClosedError("Closed server") finally: - self.__socket_factory_runner = None + self.__socket_factory_scope = None ################### # Final teardown diff --git a/src/easynetwork/api_sync/server/_base.py b/src/easynetwork/api_sync/server/_base.py index b8b244af..49f8809f 100644 --- a/src/easynetwork/api_sync/server/_base.py +++ b/src/easynetwork/api_sync/server/_base.py @@ -96,7 +96,7 @@ def shutdown(self, timeout: float | None = None) -> None: async def __do_shutdown_with_timeout(self, timeout_delay: float) -> None: backend = self.__server.get_backend() - async with backend.move_on_after(timeout_delay): + with backend.move_on_after(timeout_delay): await self.__server.shutdown() def serve_forever( diff --git a/src/easynetwork_asyncio/backend.py b/src/easynetwork_asyncio/backend.py index f9b7f8f7..9409662e 100644 --- a/src/easynetwork_asyncio/backend.py +++ b/src/easynetwork_asyncio/backend.py @@ -24,11 +24,12 @@ import contextvars import functools import itertools +import math import os import socket as _socket import sys from collections.abc import Callable, Coroutine, Mapping, Sequence -from contextlib import AbstractAsyncContextManager as AsyncContextManager, closing +from contextlib import closing from typing import TYPE_CHECKING, Any, NoReturn, ParamSpec, TypeVar try: @@ -47,7 +48,7 @@ from .datagram.socket import AsyncioTransportDatagramSocketAdapter, RawDatagramSocketAdapter from .stream.listener import AcceptedSocket, AcceptedSSLSocket, ListenerSocketAdapter from .stream.socket import AsyncioTransportStreamSocketAdapter, RawStreamSocketAdapter -from .tasks import SystemTask, TaskGroup, TaskUtils, TimeoutHandle +from .tasks import CancelScope, TaskGroup, TaskUtils from .threads import ThreadsPortal if TYPE_CHECKING: @@ -91,17 +92,8 @@ async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, _T_co]) -> _T raise TypeError("Expected a coroutine object") return await TaskUtils.cancel_shielded_await_task(asyncio.create_task(coroutine)) - def timeout(self, delay: float) -> AsyncContextManager[TimeoutHandle]: - return TaskUtils.timeout_after(delay) - - def timeout_at(self, deadline: float) -> AsyncContextManager[TimeoutHandle]: - return TaskUtils.timeout_at(deadline) - - def move_on_after(self, delay: float) -> AsyncContextManager[TimeoutHandle]: - return TaskUtils.move_on_after(delay) - - def move_on_at(self, deadline: float) -> AsyncContextManager[TimeoutHandle]: - return TaskUtils.move_on_at(deadline) + def open_cancel_scope(self, *, deadline: float = math.inf) -> CancelScope: + return CancelScope(deadline=deadline) def current_time(self) -> float: loop = asyncio.get_running_loop() @@ -115,15 +107,6 @@ async def sleep_forever(self) -> NoReturn: await loop.create_future() raise AssertionError("await an unused future cannot end in any other way than by cancellation") - def spawn_task( - self, - coro_func: Callable[..., Coroutine[Any, Any, _T]], - /, - *args: Any, - context: contextvars.Context | None = None, - ) -> SystemTask[_T]: - return SystemTask(coro_func(*args), context=context) - def create_task_group(self) -> TaskGroup: return TaskGroup() diff --git a/src/easynetwork_asyncio/tasks.py b/src/easynetwork_asyncio/tasks.py index 3f6df4ac..b1c97f5b 100644 --- a/src/easynetwork_asyncio/tasks.py +++ b/src/easynetwork_asyncio/tasks.py @@ -17,21 +17,21 @@ from __future__ import annotations -__all__ = ["Task", "TaskGroup", "TaskUtils", "TimeoutHandle"] +__all__ = ["CancelScope", "Task", "TaskGroup", "TaskUtils"] import asyncio import contextvars +import enum import math from collections import deque from collections.abc import Callable, Coroutine, Iterable -from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, final +from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, Self, TypeVar, final from weakref import WeakKeyDictionary from easynetwork.api_async.backend.abc import ( - SystemTask as AbstractSystemTask, + CancelScope as AbstractCancelScope, Task as AbstractTask, TaskGroup as AbstractTaskGroup, - TimeoutHandle as AbstractTimeoutHandle, ) if TYPE_CHECKING: @@ -43,6 +43,7 @@ _T_co = TypeVar("_T_co", covariant=True) +@final class Task(AbstractTask[_T_co]): __slots__ = ("__t", "__h") @@ -50,6 +51,9 @@ def __init__(self, task: asyncio.Task[_T_co]) -> None: self.__t: asyncio.Task[_T_co] = task self.__h: int | None = None + def __repr__(self) -> str: + return f"" + def __hash__(self) -> int: if (h := self.__h) is None: self.__h = h = hash((self.__class__, self.__t, 0xFF)) @@ -71,43 +75,17 @@ def cancelled(self) -> bool: async def wait(self) -> None: task = self.__t - try: - if task.done(): - return - await asyncio.wait({task}) - finally: - del task, self # This is needed to avoid circular reference with raised exception + await asyncio.wait({task}) async def join(self) -> _T_co: - # If the caller cancels the join() task, it should not stop the inner task - # e.g. when awaiting from an another task than the one which creates the TaskGroup, - # you want to stop joining the sub-task, not accidentally cancel it. - # It is primarily to avoid error prone code where tasks were not explicitly cancelled using task.cancel() task = self.__t try: return await asyncio.shield(task) finally: del task, self # This is needed to avoid circular reference with raised exception - @property - def _asyncio_task(self) -> asyncio.Task[_T_co]: - return self.__t - - -@final -class SystemTask(Task[_T_co], AbstractSystemTask[_T_co]): - __slots__ = () - - def __init__( - self, - coroutine: Coroutine[Any, Any, _T_co], - *, - context: contextvars.Context | None = None, - ) -> None: - super().__init__(asyncio.create_task(coroutine, context=context)) - async def join_or_cancel(self) -> _T_co: - task = self._asyncio_task + task = self.__t try: return await task finally: @@ -149,90 +127,176 @@ def start_soon( return Task(self.__asyncio_tg.create_task(coro_func(*args), context=context)) -@final -class TimeoutHandle(AbstractTimeoutHandle): - __slots__ = ("__handle", "__only_move_on", "__already_delayed_cancellation") +class _ScopeState(enum.Enum): + CREATED = "created" + ENTERED = "entered" + EXITED = "exited" + - __current_handle_dict: WeakKeyDictionary[asyncio.Task[Any], deque[TimeoutHandle]] = WeakKeyDictionary() - __delayed_task_cancel_dict: WeakKeyDictionary[asyncio.Task[Any], asyncio.Handle] = WeakKeyDictionary() +class _DelayedCancel(NamedTuple): + handle: asyncio.Handle + message: str | None - def __init__(self, handle: asyncio.Timeout, *, only_move_on: bool = False) -> None: + +@final +class CancelScope(AbstractCancelScope): + __slots__ = ( + "__host_task", + "__state", + "__cancel_called", + "__cancelled_caught", + "__task_cancelling", + "__deadline", + "__timeout_handle", + "__delayed_cancellation_on_enter", + ) + + __current_task_scope_dict: WeakKeyDictionary[asyncio.Task[Any], deque[CancelScope]] = WeakKeyDictionary() + __delayed_task_cancel_dict: WeakKeyDictionary[asyncio.Task[Any], _DelayedCancel] = WeakKeyDictionary() + + def __init__(self, *, deadline: float = math.inf) -> None: super().__init__() - self.__handle: asyncio.Timeout = handle - self.__only_move_on: bool = bool(only_move_on) - self.__already_delayed_cancellation: bool = True + self.__host_task: asyncio.Task[Any] | None = None + self.__state: _ScopeState = _ScopeState.CREATED + self.__cancel_called: bool = False + self.__cancelled_caught: bool = False + self.__task_cancelling: int = 0 + self.__deadline: float = math.inf + self.__timeout_handle: asyncio.TimerHandle | None = None + self.reschedule(deadline) + + def __repr__(self) -> str: + active = self.__state is _ScopeState.ENTERED + cancel_called = self.__cancel_called + cancelled_caught = self.__cancelled_caught + host_task = self.__host_task + deadline = self.__deadline + + info = f"{active=!r}, {cancelled_caught=!r}, {cancel_called=!r}, {host_task=!r}, {deadline=!r}" + return f"<{self.__class__.__name__}({info})>" + + def __enter__(self) -> Self: + if self.__state is not _ScopeState.CREATED: + raise RuntimeError("CancelScope entered twice") + + self.__host_task = current_task = TaskUtils.current_asyncio_task() + self.__task_cancelling = current_task.cancelling() + + current_task_scope = self.__current_task_scope_dict + if current_task not in current_task_scope: + current_task_scope[current_task] = deque() + current_task.add_done_callback(current_task_scope.pop) + current_task_scope[current_task].appendleft(self) + + self.__state = _ScopeState.ENTERED + + if self.__cancel_called: + self.__deliver_cancellation() + else: + self.__timeout() + return self - async def __aenter__(self) -> Self: - timeout_handle: asyncio.Timeout = self.__handle - await type(timeout_handle).__aenter__(timeout_handle) - current_task = TaskUtils.current_asyncio_task() - current_handle_dict = self.__current_handle_dict - if current_task not in current_handle_dict: - current_handle_dict[current_task] = deque() - current_task.add_done_callback(current_handle_dict.pop) - current_handle_dict[current_task].appendleft(self) + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool: + if self.__state is not _ScopeState.ENTERED: + raise RuntimeError("This cancel scope is not active") + + if TaskUtils.current_asyncio_task() is not self.__host_task: + raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + + if self._current_task_scope(self.__host_task) is not self: + raise RuntimeError("Attempted to exit a cancel scope that isn't the current tasks's current cancel scope") + + self.__state = _ScopeState.EXITED + + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + + host_task, self.__host_task = self.__host_task, None + self.__current_task_scope_dict[host_task].popleft() + + if self.__cancel_called: + task_cancelling = host_task.uncancel() + if isinstance(exc_val, asyncio.CancelledError): + self.__cancelled_caught = task_cancelling <= self.__task_cancelling or self.__cancellation_id() in exc_val.args + + delayed_task_cancel = self.__delayed_task_cancel_dict.get(host_task, None) + if delayed_task_cancel is not None and delayed_task_cancel.message == self.__cancellation_id(): + del self.__delayed_task_cancel_dict[host_task] + delayed_task_cancel.handle.cancel() + + current_task_scope = self._current_task_scope(host_task) + if current_task_scope is None: + if task_cancelling > 0: + self._reschedule_delayed_task_cancel(host_task, None) + else: + if current_task_scope.__cancel_called: + self._reschedule_delayed_task_cancel(host_task, current_task_scope.__cancellation_id()) + + return self.__cancelled_caught + + def __deliver_cancellation(self) -> None: + if self.__host_task is None: + # Scope not active. + return try: - self.__already_delayed_cancellation = not self.__delayed_task_cancel_dict[current_task].cancelled() + self.__delayed_task_cancel_dict.pop(self.__host_task).handle.cancel() except KeyError: - self.__already_delayed_cancellation = False - return self + pass + self.__host_task.cancel(msg=self.__cancellation_id()) - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - timeout_handle: asyncio.Timeout = self.__handle - current_task = TaskUtils.current_asyncio_task() - try: - await type(timeout_handle).__aexit__(timeout_handle, exc_type, exc_val, exc_tb) - except TimeoutError: - if self.__only_move_on: - return True - raise - else: - return None - finally: - delayed_task_cancel = None - try: - self.__current_handle_dict[current_task].popleft() - except LookupError: # pragma: no cover - pass - finally: - if not self.__already_delayed_cancellation: - if not self.__current_handle_dict.get(current_task): - delayed_task_cancel = self.__delayed_task_cancel_dict.pop(current_task, None) - if delayed_task_cancel is not None: - delayed_task_cancel.cancel() - del current_task, exc_val, exc_tb, self + def __cancellation_id(self) -> str: + return f"Cancelled by cancel scope {id(self):x}" - def when(self) -> float: - deadline: float | None = self.__handle.when() - return deadline if deadline is not None else math.inf + def cancel(self) -> None: + if not self.__cancel_called: + self.__cancel_called = True + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + self.__deliver_cancellation() - def reschedule(self, when: float) -> None: - return self.__handle.reschedule(self._cast_time(when)) + def cancel_called(self) -> bool: + return self.__cancel_called - def expired(self) -> bool: - return self.__handle.expired() + def cancelled_caught(self) -> bool: + return self.__cancelled_caught - @staticmethod - def _cast_time(time_value: float) -> float | None: - assert time_value is not None # nosec assert_used - return time_value if time_value != math.inf else None + def when(self) -> float: + return self.__deadline + + def reschedule(self, when: float, /) -> None: + if math.isnan(when): + raise ValueError("deadline is NaN") + self.__deadline = max(when, 0) + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + if self.__state is _ScopeState.ENTERED and not self.__cancel_called: + self.__timeout() + + def __timeout(self) -> None: + if self.__deadline != math.inf: + loop = asyncio.get_running_loop() + if loop.time() >= self.__deadline: + self.cancel() + else: + self.__timeout_handle = loop.call_at(self.__deadline, self.__timeout) + + @classmethod + def _current_task_scope(cls, task: asyncio.Task[Any]) -> CancelScope | None: + try: + return cls.__current_task_scope_dict[task][0] + except LookupError: + return None @classmethod def _reschedule_delayed_task_cancel(cls, task: asyncio.Task[Any], cancel_msg: str | None) -> asyncio.Handle: + if task in cls.__delayed_task_cancel_dict: + raise RuntimeError("CancelScope issue.") # pragma: no cover task_cancel_handle = task.get_loop().call_soon(cls.__cancel_task_unless_done, task, cancel_msg) - if cls.__current_handle_dict.get(task): - if task in cls.__delayed_task_cancel_dict: - cls.__delayed_task_cancel_dict[task].cancel() - cls.__delayed_task_cancel_dict[task] = task_cancel_handle - else: - assert task not in cls.__delayed_task_cancel_dict # nosec assert_used - cls.__delayed_task_cancel_dict[task] = task_cancel_handle - task.get_loop().call_soon(cls.__delayed_task_cancel_dict.pop, task, None) + cls.__delayed_task_cancel_dict[task] = _DelayedCancel(task_cancel_handle, cancel_msg) + task.get_loop().call_soon(cls.__delayed_task_cancel_dict.pop, task, None) return task_cancel_handle @staticmethod @@ -282,7 +346,7 @@ async def cancel_shielded_wait_asyncio_futures( task_cancel_msg = _get_cancelled_error_message(exc) if task_cancelled: - return TimeoutHandle._reschedule_delayed_task_cancel(current_task, task_cancel_msg) + return CancelScope._reschedule_delayed_task_cancel(current_task, task_cancel_msg) return None finally: del current_task, fs, abort_func @@ -294,7 +358,7 @@ async def cancel_shielded_coro_yield(cls) -> None: try: await asyncio.sleep(0) except asyncio.CancelledError as exc: - TimeoutHandle._reschedule_delayed_task_cancel(current_task, _get_cancelled_error_message(exc)) + CancelScope._reschedule_delayed_task_cancel(current_task, _get_cancelled_error_message(exc)) finally: del current_task @@ -311,22 +375,6 @@ async def cancel_shielded_await_task(cls, task: asyncio.Task[_T_co]) -> _T_co: finally: del task - @classmethod - def timeout_after(cls, delay: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout(TimeoutHandle._cast_time(delay))) - - @classmethod - def timeout_at(cls, deadline: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout_at(TimeoutHandle._cast_time(deadline))) - - @classmethod - def move_on_after(cls, delay: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout(TimeoutHandle._cast_time(delay)), only_move_on=True) - - @classmethod - def move_on_at(cls, deadline: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout_at(TimeoutHandle._cast_time(deadline)), only_move_on=True) - def _get_cancelled_error_message(exc: asyncio.CancelledError) -> str | None: msg: str | None diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index 4e222bf1..fe25826e 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -5,7 +5,8 @@ import time from collections.abc import Awaitable, Callable from concurrent.futures import CancelledError as FutureCancelledError, Future, wait as wait_concurrent_futures -from typing import TYPE_CHECKING, Any +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Literal from easynetwork.api_async.backend.factory import AsyncBackendFactory from easynetwork_asyncio.backend import AsyncIOBackend @@ -149,7 +150,7 @@ async def test____timeout____respected( self, backend: AsyncIOBackend, ) -> None: - async with backend.timeout(1): + with backend.timeout(1): assert await asyncio.sleep(0.5, 42) == 42 async def test____timeout____timeout_error( @@ -157,7 +158,7 @@ async def test____timeout____timeout_error( backend: AsyncIOBackend, ) -> None: with pytest.raises(TimeoutError): - async with backend.timeout(0.25): + with backend.timeout(0.25): await asyncio.sleep(0.5, 42) async def test____timeout____cancellation( @@ -166,7 +167,7 @@ async def test____timeout____cancellation( backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.timeout(0.25): + with backend.timeout(0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -180,7 +181,7 @@ async def test____timeout_at____respected( event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async with backend.timeout_at(event_loop.time() + 1): + with backend.timeout_at(event_loop.time() + 1): assert await asyncio.sleep(0.5, 42) == 42 async def test____timeout_at____timeout_error( @@ -189,7 +190,7 @@ async def test____timeout_at____timeout_error( backend: AsyncIOBackend, ) -> None: with pytest.raises(TimeoutError): - async with backend.timeout_at(event_loop.time() + 0.25): + with backend.timeout_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) async def test____timeout_at____cancellation( @@ -198,7 +199,7 @@ async def test____timeout_at____cancellation( backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.timeout_at(event_loop.time() + 0.25): + with backend.timeout_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -211,19 +212,19 @@ async def test____move_on_after____respected( self, backend: AsyncIOBackend, ) -> None: - async with backend.move_on_after(1) as handle: + with backend.move_on_after(1) as scope: assert await asyncio.sleep(0.5, 42) == 42 - assert not handle.expired() + assert not scope.cancelled_caught() async def test____move_on_after____timeout_error( self, backend: AsyncIOBackend, ) -> None: - async with backend.move_on_after(0.25) as handle: + with backend.move_on_after(0.25) as scope: await asyncio.sleep(0.5, 42) - assert handle.expired() + assert scope.cancelled_caught() async def test____move_on_after____cancellation( self, @@ -231,7 +232,7 @@ async def test____move_on_after____cancellation( backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.move_on_after(0.25): + with backend.move_on_after(0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -245,20 +246,20 @@ async def test____move_on_at____respected( event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async with backend.move_on_at(event_loop.time() + 1) as handle: + with backend.move_on_at(event_loop.time() + 1) as scope: assert await asyncio.sleep(0.5, 42) == 42 - assert not handle.expired() + assert not scope.cancelled_caught() async def test____move_on_at____timeout_error( self, event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async with backend.move_on_at(event_loop.time() + 0.25) as handle: + with backend.move_on_at(event_loop.time() + 0.25) as scope: await asyncio.sleep(0.5, 42) - assert handle.expired() + assert scope.cancelled_caught() async def test____move_on_at____cancellation( self, @@ -266,7 +267,7 @@ async def test____move_on_at____cancellation( backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.move_on_at(event_loop.time() + 0.25): + with backend.move_on_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -287,61 +288,130 @@ async def test____sleep_forever____sleep_until_cancellation( with pytest.raises(asyncio.CancelledError): await sleep_task - async def test____spawn_task____run_coroutine_in_background( + async def test____open_cancel_scope____unbound_cancel_scope____cancel_when_entering( self, + event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - return await asyncio.sleep(0.5, value) + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + scope = backend.open_cancel_scope() + scope.cancel() + assert scope.cancel_called() - task = backend.spawn_task(coroutine, 42) - await asyncio.sleep(1) - assert task.done() - assert await task.join() == 42 + assert current_task.cancelling() == 0 + await asyncio.sleep(0.1) + + with scope: + assert current_task.cancelling() == 1 + scope.cancel() + assert current_task.cancelling() == 1 + await backend.coro_yield() - async def test____spawn_task____task_cancellation( + assert current_task.cancelling() == 0 + assert scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + + async def test____open_cancel_scope____unbound_cancel_scope____deadline_scheduled_when_entering( self, event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - return await asyncio.sleep(0.5, value) + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None - task = backend.spawn_task(coroutine, 42) + scope = backend.open_cancel_scope() + scope.reschedule(event_loop.time() + 1) - event_loop.call_later(0.2, task.cancel) + assert current_task.cancelling() == 0 + await asyncio.sleep(0.5) - with pytest.raises(asyncio.CancelledError): - await task.join() + with backend.timeout(0.6): + with scope: + assert current_task.cancelling() == 0 + await backend.sleep_forever() - assert task.cancelled() + assert scope.cancelled_caught() - async def test____spawn_task____exception( + await event_loop.create_task(coroutine()) + + async def test____open_cancel_scope____overwrite_defined_deadline( self, + event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - await asyncio.sleep(0.1) - raise ZeroDivisionError + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + with backend.move_on_after(1) as scope: + await backend.sleep(0.5) + scope.deadline += 1 + await backend.sleep(1) + del scope.deadline + assert scope.deadline == float("+inf") + await backend.sleep(1) - task = backend.spawn_task(coroutine, 42) + assert not scope.cancelled_caught() - with pytest.raises(ZeroDivisionError): - await task.join() + await event_loop.create_task(coroutine()) - async def test____spawn_task____with_context( + async def test____open_cancel_scope____invalid_deadline( self, backend: AsyncIOBackend, ) -> None: - async def coroutine(value: str) -> None: - cvar_for_test.set(value) + with pytest.raises(ValueError): + _ = backend.open_cancel_scope(deadline=float("nan")) + + async def test____open_cancel_scope____context_reuse( + self, + backend: AsyncIOBackend, + ) -> None: + with backend.open_cancel_scope() as scope: + with pytest.raises(RuntimeError, match=r"^CancelScope entered twice$"): + with scope: + ... + + with pytest.raises(RuntimeError, match=r"^CancelScope entered twice$"): + with scope: + ... + + async def test____open_cancel_scope____context_exit_before_enter( + self, + backend: AsyncIOBackend, + ) -> None: + with pytest.raises(RuntimeError, match=r"^This cancel scope is not active$"), ExitStack() as stack: + stack.push(backend.open_cancel_scope()) + + async def test____open_cancel_scope____task_misnesting( + self, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> ExitStack: + stack = ExitStack() + stack.enter_context(backend.open_cancel_scope()) + return stack + + stack = await asyncio.create_task(coroutine()) + with pytest.raises(RuntimeError, match=r"^Attempted to exit cancel scope in a different task than it was entered in$"): + stack.close() - cvar_for_test.set("something") - ctx = contextvars.copy_context() - task = backend.spawn_task(coroutine, "other", context=ctx) - await task.wait() - assert cvar_for_test.get() == "something" - assert ctx.run(cvar_for_test.get) == "other" + async def test____open_cancel_scope____scope_misnesting( + self, + backend: AsyncIOBackend, + ) -> None: + stack = ExitStack() + stack.enter_context(backend.open_cancel_scope()) + with backend.open_cancel_scope(): + with pytest.raises( + RuntimeError, match=r"^Attempted to exit a cancel scope that isn't the current tasks's current cancel scope$" + ): + stack.close() + stack.pop_all() async def test____create_task_group____task_pool( self, @@ -402,8 +472,10 @@ async def coroutine(value: int) -> int: # Tasks cannot be cancelled twice assert not task_42.cancel() + @pytest.mark.parametrize("join_method", ["join", "join_or_cancel"]) async def test____create_task_group____task_join_cancel_shielding( self, + join_method: Literal["join", "join_or_cancel"], event_loop: asyncio.AbstractEventLoop, backend: AsyncIOBackend, ) -> None: @@ -413,15 +485,24 @@ async def coroutine(value: int) -> int: async with backend.create_task_group() as task_group: inner_task = task_group.start_soon(coroutine, 42) - outer_task = event_loop.create_task(inner_task.join()) + match join_method: + case "join": + outer_task = event_loop.create_task(inner_task.join()) + case "join_or_cancel": + outer_task = event_loop.create_task(inner_task.join_or_cancel()) + case _: + pytest.fail("invalid argument") event_loop.call_later(0.2, outer_task.cancel) with pytest.raises(asyncio.CancelledError): await outer_task assert outer_task.cancelled() - assert not inner_task.cancelled() - assert await inner_task.join() == 42 + if join_method == "join_or_cancel": + assert inner_task.cancelled() + else: + assert not inner_task.cancelled() + assert await inner_task.join() == 42 async def test____create_task_group____start_soon_with_context( self, @@ -987,11 +1068,12 @@ async def test____cancel_shielded_coroutine____do_not_cancel_at_timeout_end( checkpoints: list[str] = [] async def coroutine(value: int) -> int: - async with backend.timeout(0) as handle: + with backend.timeout(0) as scope: await cancel_shielded_coroutine() checkpoints.append("cancel_shielded_coroutine") - assert handle.expired() # Context manager did not raise but effectively tried to cancel the task + assert scope.cancel_called() + assert not scope.cancelled_caught() await backend.coro_yield() checkpoints.append("coro_yield") return value @@ -1013,9 +1095,9 @@ async def coroutine(value: int) -> int: current_task = asyncio.current_task() assert current_task is not None - async with backend.move_on_after(0) as handle: - async with backend.timeout(0): - async with backend.timeout(0): + with backend.move_on_after(0) as scope: + with backend.timeout(0): + with backend.timeout(0): await cancel_shielded_coroutine() checkpoints.append("inner_cancel_shielded_coroutine") assert current_task.cancelling() == 3 @@ -1031,7 +1113,8 @@ async def coroutine(value: int) -> int: assert current_task.cancelling() == 0 - assert handle.expired() + assert scope.cancel_called() + assert scope.cancelled_caught() await backend.coro_yield() checkpoints.append("coro_yield") return value @@ -1056,7 +1139,7 @@ async def coroutine(value: int) -> int: await cancel_shielded_coroutine() checkpoints.append("cancel_shielded_coroutine") - async with backend.timeout(0): + with backend.timeout(0): await cancel_shielded_coroutine() checkpoints.append("inner_cancel_shielded_coroutine") assert current_task.cancelling() == 2 @@ -1073,6 +1156,89 @@ async def coroutine(value: int) -> int: await task assert checkpoints == ["cancel_shielded_coroutine", "inner_cancel_shielded_coroutine"] + async def test____cancel_shielded_coroutine____cancel_at_timeout_end_if_task_cancellation_does_not_come_from_scope( + self, + cancel_shielded_coroutine: Callable[[], Awaitable[Any]], + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + checkpoints: list[str] = [] + + async def coroutine(value: int) -> int: + current_task = asyncio.current_task() + assert current_task is not None + + with backend.open_cancel_scope(): + with backend.open_cancel_scope(): + await cancel_shielded_coroutine() + checkpoints.append("cancel_shielded_coroutine") + assert current_task.cancelling() == 1 + + assert current_task.cancelling() == 1 + await backend.coro_yield() + checkpoints.append("should_not_be_here") + return value + + task = event_loop.create_task(coroutine(42)) + event_loop.call_soon(task.cancel) + + with pytest.raises(asyncio.CancelledError): + await task + assert checkpoints == ["cancel_shielded_coroutine"] + + async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_1( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + inner_scope = backend.open_cancel_scope() + with backend.move_on_after(0.5) as outer_scope: + with inner_scope: + await backend.ignore_cancellation(backend.sleep(1)) + inner_scope.cancel() + await backend.coro_yield() + + assert outer_scope.cancel_called() + assert inner_scope.cancel_called() + + assert not outer_scope.cancelled_caught() + assert inner_scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + + async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_2( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + outer_scope = backend.move_on_after(1.5) + inner_scope = backend.move_on_after(0.5) + with outer_scope: + with inner_scope: + await backend.ignore_cancellation(backend.sleep(1)) + assert not inner_scope.cancelled_caught() + try: + await backend.coro_yield() + except asyncio.CancelledError: + pytest.fail("Cancelled") + await backend.sleep(1) + + assert outer_scope.cancel_called() + assert inner_scope.cancel_called() + + assert outer_scope.cancelled_caught() + assert not inner_scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + async def test____ignore_cancellation____do_not_reschedule_if_inner_task_cancelled_itself( self, event_loop: asyncio.AbstractEventLoop, diff --git a/tests/functional_test/test_async/test_backend/test_tasks.py b/tests/functional_test/test_async/test_backend/test_tasks.py deleted file mode 100644 index 5414ec31..00000000 --- a/tests/functional_test/test_async/test_backend/test_tasks.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, Any - -from easynetwork.api_async.backend.abc import AsyncBackend -from easynetwork.api_async.backend.factory import AsyncBackendFactory -from easynetwork.api_async.backend.tasks import SingleTaskRunner - -import pytest - -if TYPE_CHECKING: - from pytest_mock import MockerFixture - - -@pytest.mark.asyncio -class TestSingleTaskRunner: - @pytest.fixture - @staticmethod - def backend() -> AsyncBackend: - return AsyncBackendFactory.new("asyncio") - - async def test____run____run_task_once( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, return_value=mocker.sentinel.task_result) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 1, 2, 3, arg1=1, arg2=2, arg3=3) - - result1 = await runner.run() - result2 = await runner.run() - - assert result1 is mocker.sentinel.task_result - assert result2 is mocker.sentinel.task_result - coro_func.assert_awaited_once_with(1, 2, 3, arg1=1, arg2=2, arg3=3) - - async def test____run____early_cancel( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, return_value=mocker.sentinel.task_result) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 1, 2, 3, arg1=1, arg2=2, arg3=3) - - assert runner.cancel() - assert runner.cancel() # Cancel twice does nothing - - with pytest.raises(asyncio.CancelledError): - await runner.run() - - coro_func.assert_not_awaited() - coro_func.assert_not_called() - - async def test____run____cancel_while_running(self, backend: AsyncBackend) -> None: - async def coro_func(value: int) -> int: - return await asyncio.sleep(3600, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - await asyncio.sleep(0) - assert not task.done() - - assert runner.cancel() - assert runner.cancel() # Cancel twice does nothing - - with pytest.raises(asyncio.CancelledError): - await runner.run() - - assert task.cancelled() - - async def test____run____unhandled_exceptions( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - my_exc = OSError() - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, side_effect=[my_exc]) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func) - - with pytest.raises(OSError) as exc_info_run_1: - _ = await runner.run() - with pytest.raises(OSError) as exc_info_run_2: - _ = await runner.run() - - assert exc_info_run_1.value is my_exc - assert exc_info_run_2.value is my_exc - - async def test____run____waiting_task_is_cancelled(self, backend: AsyncBackend) -> None: - inner_task: list[asyncio.Task[int] | None] = [] - - async def coro_func(value: int) -> int: - inner_task.append(asyncio.current_task()) - return await asyncio.sleep(3600, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - await asyncio.sleep(0.1) - - task.cancel() - - with pytest.raises(asyncio.CancelledError): - await task - - assert inner_task[0] is not None and inner_task[0].cancelled() - - async def test____run____waiting_task_is_cancelled____not_the_first_runner(self, backend: AsyncBackend) -> None: - async def coro_func(value: int) -> int: - return await asyncio.sleep(0.5, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - task_2 = asyncio.create_task(runner.run()) - await asyncio.sleep(0.1) - - task_2.cancel() - - assert await task == 42 - with pytest.raises(asyncio.CancelledError): - await task_2 diff --git a/tests/functional_test/test_communication/test_async/test_server/base.py b/tests/functional_test/test_communication/test_async/test_server/base.py index 6c8c90ac..738c7d6c 100644 --- a/tests/functional_test/test_communication/test_async/test_server/base.py +++ b/tests/functional_test/test_communication/test_async/test_server/base.py @@ -67,6 +67,22 @@ async def test____serve_forever____shutdown_during_setup( await server.shutdown() assert event.is_set() + async def test____serve_forever____server_close_during_setup( + self, + server: AbstractAsyncNetworkServer, + ) -> None: + event = asyncio.Event() + server_task = None + with pytest.raises(ExceptionGroup): + async with asyncio.TaskGroup() as tg: + server_task = tg.create_task(server.serve_forever(is_up_event=event)) + await asyncio.sleep(0) + assert not event.is_set() + await server.server_close() + assert not event.is_set() + assert server_task is not None + assert isinstance(server_task.exception(), ServerClosedError) + async def test____serve_forever____without_is_up_event( self, server: AbstractAsyncNetworkServer, diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 29f46a23..277f4c59 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -167,7 +167,7 @@ async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, s await client.send_packet(request) try: with pytest.raises(TimeoutError): - async with self.backend.timeout(self.request_timeout): + with self.backend.timeout(self.request_timeout): yield await client.send_packet("successfully timed out") finally: @@ -195,7 +195,7 @@ async def on_connection(self, client: AsyncStreamClient[str]) -> AsyncGenerator[ if self.bypass_handshake: return try: - async with self.backend.timeout(1): + with self.backend.timeout(1): password = yield except TimeoutError: await client.send_packet("timeout error") diff --git a/tests/unit_test/test_async/conftest.py b/tests/unit_test/test_async/conftest.py index 9a31e7ce..8a08211c 100644 --- a/tests/unit_test/test_async/conftest.py +++ b/tests/unit_test/test_async/conftest.py @@ -32,14 +32,13 @@ def fake_cancellation_cls() -> type[BaseException]: @pytest.fixture def mock_backend(fake_cancellation_cls: type[BaseException], mocker: MockerFixture) -> MagicMock: - from easynetwork_asyncio.tasks import SystemTask, TaskGroup + from easynetwork_asyncio.tasks import TaskGroup from .._utils import AsyncDummyLock mock_backend = mocker.NonCallableMagicMock(spec=AsyncBackend) mock_backend.get_cancelled_exc_class.return_value = fake_cancellation_cls - mock_backend.spawn_task = lambda coro_func, *args, **kwargs: SystemTask(coro_func(*args), **kwargs) mock_backend.create_lock = AsyncDummyLock mock_backend.create_event = asyncio.Event mock_backend.create_task_group = TaskGroup diff --git a/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py b/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py index 1461ddf7..d1c85cef 100644 --- a/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py +++ b/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Coroutine, Sequence from socket import socket as Socket -from typing import Any, AsyncContextManager, NoReturn, final +from typing import Any, NoReturn, final from easynetwork.api_async.backend.abc import ( AsyncBackend, @@ -12,10 +12,8 @@ ICondition, IEvent, ILock, - SystemTask, TaskGroup, ThreadsPortal, - TimeoutHandle, ) @@ -41,22 +39,22 @@ async def cancel_shielded_coro_yield(self) -> None: async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, Any]) -> Any: raise NotImplementedError - def timeout(self, delay: Any) -> AsyncContextManager[TimeoutHandle]: + def timeout(self, delay: Any) -> Any: raise NotImplementedError - def timeout_at(self, deadline: Any) -> AsyncContextManager[TimeoutHandle]: + def timeout_at(self, deadline: Any) -> Any: raise NotImplementedError - def move_on_after(self, delay: Any) -> AsyncContextManager[TimeoutHandle]: + def move_on_after(self, delay: Any) -> Any: raise NotImplementedError - def move_on_at(self, deadline: Any) -> AsyncContextManager[TimeoutHandle]: + def move_on_at(self, deadline: Any) -> Any: raise NotImplementedError - def get_cancelled_exc_class(self) -> type[BaseException]: + def open_cancel_scope(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError - def spawn_task(self, *args: Any, **kwargs: Any) -> SystemTask[Any]: + def get_cancelled_exc_class(self) -> type[BaseException]: raise NotImplementedError def create_task_group(self) -> TaskGroup: diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py b/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py index fb821898..b81801bd 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py @@ -1,16 +1,14 @@ from __future__ import annotations import asyncio -import math -from collections.abc import Callable from typing import TYPE_CHECKING, Any -from easynetwork_asyncio.tasks import Task, TaskUtils, TimeoutHandle +from easynetwork_asyncio.tasks import Task, TaskUtils import pytest if TYPE_CHECKING: - from unittest.mock import AsyncMock, MagicMock + from unittest.mock import AsyncMock from pytest_mock import MockerFixture @@ -123,7 +121,7 @@ async def test____wait____task_already_done( await task.wait() # Assert - mock_asyncio_wait.assert_not_called() + mock_asyncio_wait.assert_awaited_once_with({mock_asyncio_task}) @pytest.mark.asyncio async def test____join____await_task( @@ -147,199 +145,6 @@ async def test____join____await_task( assert result is mocker.sentinel.task_result -class TestTimeout: - @pytest.fixture - @staticmethod - def mock_asyncio_timeout_handle(mocker: MockerFixture) -> MagicMock: - mock = mocker.NonCallableMagicMock(spec=asyncio.Timeout) - mock.__aenter__.return_value = mock - mock.when.return_value = None - mock.reschedule.return_value = None - mock.expired.return_value = False - return mock - - @pytest.fixture - @staticmethod - def timeout_handle(mock_asyncio_timeout_handle: MagicMock) -> TimeoutHandle: - return TimeoutHandle(mock_asyncio_timeout_handle) - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout", [TaskUtils.timeout_after, TaskUtils.move_on_after]) - async def test____timeout____schedule_timeout( - self, - timeout: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout = mocker.patch("asyncio.timeout", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout(123456789): - pass - - # Assert - mock_timeout.assert_called_once_with(123456789.0) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout", [TaskUtils.timeout_after, TaskUtils.move_on_after]) - async def test____timeout____handle_infinite( - self, - timeout: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout = mocker.patch("asyncio.timeout", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout(math.inf): - pass - - # Assert - mock_timeout.assert_called_once_with(None) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout_at", [TaskUtils.timeout_at, TaskUtils.move_on_at]) - async def test____timeout_at____schedule_timeout( - self, - timeout_at: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout_at = mocker.patch("asyncio.timeout_at", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout_at(123456789): - pass - - # Assert - mock_timeout_at.assert_called_once_with(123456789.0) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout_at", [TaskUtils.timeout_at, TaskUtils.move_on_at]) - async def test____timeout_at____handle_infinite( - self, - timeout_at: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout_at = mocker.patch("asyncio.timeout_at", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout_at(math.inf): - pass - - # Assert - mock_timeout_at.assert_called_once_with(None) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - def test____timeout_handle____when____return_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = 123456.789 - - # Act - deadline: float = timeout_handle.when() - - # Assert - assert deadline == 123456.789 - - def test____timeout_handle____when____infinite_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = None - - # Act - deadline: float = timeout_handle.when() - - # Assert - assert deadline == math.inf - - def test____timeout_handle____reschedule____set_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - - # Act - timeout_handle.reschedule(123456.789) - - # Assert - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(123456.789) - - def test____timeout_handle____reschedule____infinite_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - - # Act - timeout_handle.reschedule(math.inf) - - # Assert - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(None) - - @pytest.mark.parametrize("expected_retval", [False, True]) - def test____expired____expected_return_value( - self, - expected_retval: bool, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.expired.return_value = expected_retval - - # Act - expired: bool = timeout_handle.expired() - - # Assert - assert expired is expected_retval - - def test____deadline_property____set( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = 4 - - # Act - timeout_handle.deadline += 12 - - # Assert - mock_asyncio_timeout_handle.when.assert_called_once_with() - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(16) - - def test____deadline_property____delete( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = 4 - - # Act - del timeout_handle.deadline - - # Assert - mock_asyncio_timeout_handle.when.assert_not_called() - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(None) - - class TestTaskUtils: def test____current_asyncio_task____return_current_task(self) -> None: # Arrange