From c58962b4d9431bf5f6913c5ea3a3affa2a626d98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sun, 24 Sep 2023 19:30:58 +0200 Subject: [PATCH] Improved task management in the backend API (#129) --- docs/source/api/async/backend.rst | 32 +- pyproject.toml | 3 - src/easynetwork/api_async/backend/abc.py | 494 +++++++------- src/easynetwork/api_async/backend/factory.py | 9 +- src/easynetwork/api_async/backend/futures.py | 53 +- src/easynetwork/api_async/backend/tasks.py | 133 ---- src/easynetwork/api_async/client/abc.py | 2 +- src/easynetwork/api_async/client/tcp.py | 136 ++-- src/easynetwork/api_async/client/udp.py | 116 ++-- src/easynetwork/api_async/server/tcp.py | 52 +- src/easynetwork/api_async/server/udp.py | 36 +- src/easynetwork/api_sync/server/_base.py | 121 ++-- src/easynetwork/tools/_utils.py | 21 +- src/easynetwork_asyncio/__init__.py | 6 +- src/easynetwork_asyncio/backend.py | 128 +--- src/easynetwork_asyncio/datagram/endpoint.py | 51 +- src/easynetwork_asyncio/datagram/socket.py | 20 - src/easynetwork_asyncio/runner.py | 49 -- src/easynetwork_asyncio/socket.py | 8 +- src/easynetwork_asyncio/stream/listener.py | 5 +- src/easynetwork_asyncio/stream/socket.py | 57 +- src/easynetwork_asyncio/tasks.py | 355 +++++++--- src/easynetwork_asyncio/threads.py | 155 +++-- .../test_backend/test_asyncio_backend.py | 644 +++++++++++++----- .../test_backend/test_backend_factory.py | 12 +- .../test_async/test_backend/test_futures.py | 48 ++ .../test_async/test_backend/test_tasks.py | 123 ---- .../test_async/test_server/base.py | 16 + .../test_async/test_server/test_tcp.py | 8 +- .../test_sync/test_server/test_standalone.py | 45 +- tests/unit_test/test_async/base.py | 34 - tests/unit_test/test_async/conftest.py | 22 +- .../test_api/test_backend/_fake_backends.py | 19 +- .../test_api/test_backend/test_backend.py | 4 +- .../test_api/test_backend/test_futures.py | 93 ++- .../test_async/test_api/test_client/base.py | 4 +- .../test_api/test_client/test_tcp.py | 88 +-- .../test_api/test_client/test_udp.py | 49 +- .../test_asyncio_backend/test_backend.py | 122 +++- .../test_asyncio_backend/test_datagram.py | 146 ++-- .../test_asyncio_backend/test_stream.py | 191 +----- .../test_asyncio_backend/test_tasks.py | 214 +----- .../test_asyncio_backend/test_threads.py | 156 ----- tests/unit_test/test_tools/test_utils.py | 26 +- 44 files changed, 1842 insertions(+), 2264 deletions(-) delete mode 100644 src/easynetwork/api_async/backend/tasks.py delete mode 100644 src/easynetwork_asyncio/runner.py delete mode 100644 tests/functional_test/test_async/test_backend/test_tasks.py delete mode 100644 tests/unit_test/test_async/base.py delete mode 100644 tests/unit_test/test_async/test_asyncio_backend/test_threads.py diff --git a/docs/source/api/async/backend.rst b/docs/source/api/async/backend.rst index 6ce7807d..24711597 100644 --- a/docs/source/api/async/backend.rst +++ b/docs/source/api/async/backend.rst @@ -32,12 +32,6 @@ Runners .. automethod:: AsyncBackend.bootstrap -.. automethod:: AsyncBackend.new_runner - -.. autoclass:: Runner - :members: - :special-members: __enter__, __exit__ - Coroutines And Tasks -------------------- @@ -78,17 +72,11 @@ Creating Concurrent Tasks .. autoclass:: Task :members: -Spawning System Tasks -""""""""""""""""""""" - -.. automethod:: AsyncBackend.spawn_task - -.. autoclass:: SystemTask - :members: - Timeouts ^^^^^^^^ +.. automethod:: AsyncBackend.open_cancel_scope + .. automethod:: AsyncBackend.move_on_after .. automethod:: AsyncBackend.move_on_at @@ -97,7 +85,7 @@ Timeouts .. automethod:: AsyncBackend.timeout_at -.. autoclass:: TimeoutHandle +.. autoclass:: CancelScope :members: Networking @@ -135,9 +123,6 @@ Socket Adapter Classes .. autoclass:: AsyncStreamSocketAdapter :members: -.. autoclass:: AsyncHalfCloseableStreamSocketAdapter - :members: - .. autoclass:: AsyncDatagramSocketAdapter :members: @@ -187,6 +172,7 @@ Scheduling From Other Threads .. autoclass:: ThreadsPortal :members: + :special-members: __aenter__, __aexit__ ``concurrent.futures`` Integration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -213,16 +199,6 @@ Backend Factory :exclude-members: GROUP_NAME -Tasks Utilities -=============== - -.. automodule:: easynetwork.api_async.backend.tasks - :no-docstring: - -.. autoclass:: SingleTaskRunner - :members: - - Concurrency And Multithreading (``concurrent.futures`` Integration) =================================================================== diff --git a/pyproject.toml b/pyproject.toml index c4d468f2..96d958b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,6 @@ sniffio = [ "sniffio>=1.3.0", ] -[project.entry-points."easynetwork.async.backends"] -asyncio = "easynetwork_asyncio:AsyncioBackend" - ############################ pdm configuration ############################ [tool.pdm.dev-dependencies] diff --git a/src/easynetwork/api_async/backend/abc.py b/src/easynetwork/api_async/backend/abc.py index 6b8d34b4..53f77322 100644 --- a/src/easynetwork/api_async/backend/abc.py +++ b/src/easynetwork/api_async/backend/abc.py @@ -21,23 +21,23 @@ "AsyncBackend", "AsyncBaseSocketAdapter", "AsyncDatagramSocketAdapter", - "AsyncHalfCloseableStreamSocketAdapter", "AsyncListenerSocketAdapter", "AsyncStreamSocketAdapter", + "CancelScope", "ICondition", "IEvent", "ILock", "Task", "TaskGroup", "ThreadsPortal", - "TimeoutHandle", ] +import contextlib import contextvars import math from abc import ABCMeta, abstractmethod -from collections.abc import Callable, Coroutine, Iterable, Sequence -from contextlib import AbstractAsyncContextManager +from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Mapping, Sequence +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Any, Generic, NoReturn, ParamSpec, Protocol, Self, TypeVar if TYPE_CHECKING: @@ -187,55 +187,6 @@ async def wait(self) -> Any: # pragma: no cover ... -class Runner(metaclass=ABCMeta): - """ - A :term:`context manager` that simplifies `multiple` async function calls in the same context. - - Sometimes several top-level async functions should be called in the same event loop and :class:`contextvars.Context`. - """ - - __slots__ = ("__weakref__",) - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: - """Calls :meth:`close`.""" - self.close() - - @abstractmethod - def close(self) -> None: - """ - Closes the runner. - """ - raise NotImplementedError - - @abstractmethod - def run(self, coro_func: Callable[..., Coroutine[Any, Any, _T]], *args: Any) -> _T: - """ - Runs an async function, and returns the result. - - Calling:: - - runner.run(coro_func, *args) - - is equivalent to:: - - await coro_func(*args) - - except that :meth:`run` can (and must) be called from a synchronous context. - - Parameters: - coro_func: An async function. - args: Positional arguments to be passed to `coro_func`. If you need to pass keyword arguments, - then use :func:`functools.partial`. - - Returns: - Whatever `coro_func` returns. - """ - raise NotImplementedError - - class Task(Generic[_T_co], metaclass=ABCMeta): """ A :class:`Task` object represents a concurrent "thread" of execution. @@ -246,12 +197,12 @@ class Task(Generic[_T_co], metaclass=ABCMeta): @abstractmethod def done(self) -> bool: """ - Returns :data:`True` if the Task is done. + Returns the Task state. A Task is *done* when the wrapped coroutine either returned a value, raised an exception, or the Task was cancelled. Returns: - The Task state. + :data:`True` if the Task is done. """ raise NotImplementedError @@ -275,13 +226,13 @@ def cancel(self) -> bool: @abstractmethod def cancelled(self) -> bool: """ - Returns :data:`True` if the Task is *cancelled*. + Returns the cancellation state. The Task is *cancelled* when the cancellation was requested with :meth:`cancel` and the wrapped coroutine propagated the ``backend.get_cancelled_exc_class()`` exception thrown into it. Returns: - the cancellation state. + :data:`True` if the Task is *cancelled* """ raise NotImplementedError @@ -314,14 +265,6 @@ async def join(self) -> _T_co: """ raise NotImplementedError - -class SystemTask(Task[_T_co]): - """ - A :class:`SystemTask` is a :class:`Task` that runs concurrently with the current root task. - """ - - __slots__ = () - @abstractmethod async def join_or_cancel(self) -> _T_co: """ @@ -342,6 +285,104 @@ async def join_or_cancel(self) -> _T_co: raise NotImplementedError +class CancelScope(metaclass=ABCMeta): + """ + A temporary scope opened by a task that can be used by other tasks to control its execution time. + + Unlike trio's CancelScope, there is no "shielded" scopes; you must use :meth:`AsyncBackend.ignore_cancellation`. + """ + + __slots__ = ("__weakref__",) + + @abstractmethod + def __enter__(self) -> Self: + raise NotImplementedError + + @abstractmethod + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool: + raise NotImplementedError + + @abstractmethod + def cancel(self) -> None: + """ + Request the Task to be cancelled. + + This arranges for a ``backend.get_cancelled_exc_class()`` exception to be thrown into the wrapped coroutine + on the next cycle of the event loop. + + :meth:`CancelScope.cancel` does not guarantee that the Task will be cancelled, + although suppressing cancellation completely is not common and is actively discouraged. + """ + raise NotImplementedError + + @abstractmethod + def cancel_called(self) -> bool: + """ + Checks if :meth:`cancel` has been called. + + Returns: + :data:`True` if :meth:`cancel` has been called. + """ + raise NotImplementedError + + @abstractmethod + def cancelled_caught(self) -> bool: + """ + Returns the scope cancellation state. + + Returns: + :data:`True` if the scope has been is *cancelled*. + """ + raise NotImplementedError + + @abstractmethod + def when(self) -> float: + """ + Returns the current deadline. + + Returns: + the absolute time in seconds. :data:`math.inf` if the current deadline is not set. + """ + raise NotImplementedError + + @abstractmethod + def reschedule(self, when: float, /) -> None: + """ + Reschedules the timeout. + + Parameters: + when: The new deadline. + """ + raise NotImplementedError + + @property + def deadline(self) -> float: + """ + A read-write attribute to simplify the timeout management. + + For example, this statement:: + + scope.deadline += 30 + + is equivalent to:: + + scope.reschedule(scope.when() + 30) + + It is also possible to remove the timeout by deleting the attribute:: + + del scope.deadline + """ + return self.when() + + @deadline.setter + def deadline(self, value: float) -> None: + self.reschedule(value) + + @deadline.deleter + def deadline(self) -> None: + self.reschedule(math.inf) + + class TaskGroup(metaclass=ABCMeta): """ Groups several asynchronous tasks together. @@ -402,51 +443,121 @@ def start_soon( class ThreadsPortal(metaclass=ABCMeta): """ An object that lets external threads run code in an asynchronous event loop. + + You must use it as a context manager *within* the event loop to start the portal:: + + async with threads_portal: + ... + + If the portal is not entered or exited, then all of the operations would throw a :exc:`RuntimeError` for the threads. """ __slots__ = ("__weakref__",) @abstractmethod - def run_coroutine(self, coro_func: Callable[_P, Coroutine[Any, Any, _T]], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + async def __aenter__(self) -> Self: + raise NotImplementedError + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def run_coroutine_soon( + self, + coro_func: Callable[_P, Awaitable[_T]], + /, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> concurrent.futures.Future[_T]: + """ + Run the given async function in the bound event loop thread. Thread-safe. + + Parameters: + coro_func: An async function. + args: Positional arguments to be passed to `coro_func`. + kwargs: Keyword arguments to be passed to `coro_func`. + + Raises: + RuntimeError: if the portal is shut down. + RuntimeError: if you try calling this from inside the event loop thread, to avoid potential deadlocks. + + Returns: + A future filled with the result of ``await coro_func(*args, **kwargs)``. + """ + raise NotImplementedError + + def run_coroutine(self, coro_func: Callable[_P, Awaitable[_T]], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: """ Run the given async function in the bound event loop thread, blocking until it is complete. Thread-safe. + The default implementation is equivalent to:: + + portal.run_coroutine_soon(coro_func, *args, **kwargs).result() + Parameters: coro_func: An async function. args: Positional arguments to be passed to `coro_func`. kwargs: Keyword arguments to be passed to `coro_func`. Raises: - backend.get_cancelled_exc_class(): The scheduler was shut down while ``coro_func()`` was running + concurrent.futures.CancelledError: The portal has been shut down while ``coro_func()`` was running and cancelled the task. - RuntimeError: if the scheduler is shut down. + RuntimeError: if the portal is shut down. RuntimeError: if you try calling this from inside the event loop thread, which would otherwise cause a deadlock. - Exception: Whatever raises ``coro_func(*args, **kwargs)`` + Exception: Whatever raises ``await coro_func(*args, **kwargs)``. Returns: - Whatever returns ``coro_func(*args, **kwargs)`` + Whatever returns ``await coro_func(*args, **kwargs)``. """ - raise NotImplementedError + return self.run_coroutine_soon(coro_func, *args, **kwargs).result() @abstractmethod + def run_sync_soon(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> concurrent.futures.Future[_T]: + """ + Executes a function in the event loop thread from a worker thread. Thread-safe. + + Parameters: + func: A synchronous function. + args: Positional arguments to be passed to `func`. + kwargs: Keyword arguments to be passed to `func`. + + Raises: + RuntimeError: if the portal is shut down. + RuntimeError: if you try calling this from inside the event loop thread, to avoid potential deadlocks. + + Returns: + A future filled with the result of ``func(*args, **kwargs)``. + """ + raise NotImplementedError + def run_sync(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: """ Executes a function in the event loop thread from a worker thread. Thread-safe. + The default implementation is equivalent to:: + + portal.run_sync_soon(func, *args, **kwargs).result() + Parameters: func: A synchronous function. args: Positional arguments to be passed to `func`. kwargs: Keyword arguments to be passed to `func`. Raises: - RuntimeError: if the scheduler is shut down. + RuntimeError: if the portal is shut down. RuntimeError: if you try calling this from inside the event loop thread, which would otherwise cause a deadlock. - Exception: Whatever raises ``func(*args, **kwargs)`` + Exception: Whatever raises ``func(*args, **kwargs)``. Returns: - Whatever returns ``func(*args, **kwargs)`` + Whatever returns ``func(*args, **kwargs)``. """ - raise NotImplementedError + return self.run_sync_soon(func, *args, **kwargs).result() class AsyncBaseSocketAdapter(metaclass=ABCMeta): @@ -487,16 +598,6 @@ async def aclose(self) -> None: """ raise NotImplementedError - @abstractmethod - def get_local_address(self) -> tuple[Any, ...]: - """ - Returns the local socket address. Roughly similar to :meth:`socket.socket.getsockname`. - - Returns: - The socket address. - """ - raise NotImplementedError - @abstractmethod def socket(self) -> ISocket: """ @@ -515,16 +616,6 @@ class AsyncStreamSocketAdapter(AsyncBaseSocketAdapter): __slots__ = () - @abstractmethod - def get_remote_address(self) -> tuple[Any, ...]: - """ - Returns the remote endpoint's address. Roughly similar to :meth:`socket.socket.getpeername`. - - Returns: - The remote address. - """ - raise NotImplementedError - @abstractmethod async def recv(self, bufsize: int, /) -> bytes: """ @@ -557,14 +648,6 @@ async def sendall_fromiter(self, iterable_of_data: Iterable[bytes], /) -> None: """ await self.sendall(b"".join(iterable_of_data)) - -class AsyncHalfCloseableStreamSocketAdapter(AsyncStreamSocketAdapter): - """ - A stream-oriented socket interface that also supports closing only the write end of the stream. - """ - - __slots__ = () - @abstractmethod async def send_eof(self) -> None: """ @@ -580,16 +663,6 @@ class AsyncDatagramSocketAdapter(AsyncBaseSocketAdapter): __slots__ = () - @abstractmethod - def get_remote_address(self) -> tuple[Any, ...] | None: - """ - Returns the remote endpoint's address. Roughly similar to :meth:`socket.socket.getpeername`. - - Returns: - The remote address if configured, :data:`None` otherwise. - """ - raise NotImplementedError - @abstractmethod async def recvfrom(self, bufsize: int, /) -> tuple[bytes, tuple[Any, ...]]: """ @@ -682,114 +755,54 @@ async def connect(self) -> AsyncStreamSocketAdapter: raise NotImplementedError -class TimeoutHandle(metaclass=ABCMeta): +class AsyncBackend(metaclass=ABCMeta): """ - Interface to deal with an actual timeout scope. + Asynchronous backend interface. - See :meth:`AsyncBackend.move_on_after` for details. + It bridges the gap between asynchronous frameworks (``asyncio``, ``trio``, or whatever) and EasyNetwork. """ - __slots__ = () - - @abstractmethod - def when(self) -> float: - """ - Returns the current deadline. - - Returns: - the absolute time in seconds. :data:`math.inf` if the current deadline is not set. - A negative value can be returned. - """ - raise NotImplementedError - - @abstractmethod - def reschedule(self, when: float, /) -> None: - """ - Reschedules the timeout. - - Parameters: - when: The new deadline. - """ - raise NotImplementedError + __slots__ = ("__weakref__",) @abstractmethod - def expired(self) -> bool: - """ - Returns whether the context manager has exceeded its deadline (expired). - - Returns: - the timeout state. - """ - raise NotImplementedError - - @property - def deadline(self) -> float: + def bootstrap( + self, + coro_func: Callable[..., Coroutine[Any, Any, _T]], + *args: Any, + runner_options: Mapping[str, Any] | None = ..., + ) -> _T: """ - A read-write attribute to simplify the timeout management. + Runs an async function, and returns the result. - For example, this statement:: + Calling:: - handle.deadline += 30 + backend.bootstrap(coro_func, *args) is equivalent to:: - handle.reschedule(handle.when() + 30) - - It is also possible to remove the timeout by deleting the attribute:: - - del handle.deadline - """ - return self.when() - - @deadline.setter - def deadline(self, value: float) -> None: - self.reschedule(value) - - @deadline.deleter - def deadline(self) -> None: - self.reschedule(math.inf) - - -class AsyncBackend(metaclass=ABCMeta): - """ - Asynchronous backend interface. - - It bridges the gap between asynchronous frameworks (``asyncio``, ``trio``, or whatever) and EasyNetwork. - """ - - __slots__ = ("__weakref__",) - - @abstractmethod - def new_runner(self) -> Runner: - """ - Returns an asynchronous function runner. + await coro_func(*args) - Returns: - A :class:`Runner` context. - """ - raise NotImplementedError + except that :meth:`bootstrap` can (and must) be called from a synchronous context. - def bootstrap(self, coro_func: Callable[..., Coroutine[Any, Any, _T]], *args: Any) -> _T: - """ - Runs an async function, and returns the result. + `runner_options` can be used to give additional parameters to the backend runner. For example:: - Equivalent to:: + backend.bootstrap(coro_func, *args, runner_options={"loop_factory": uvloop.new_event_loop}) - with backend.new_runner() as runner: - return runner.run(coro_func, *args) + would act as the following for :mod:`asyncio`:: - See :meth:`Runner.run` documentation for details. + with asyncio.Runner(loop_factory=uvloop.new_event_loop): + runner.run(coro_func(*args)) Parameters: coro_func: An async function. args: Positional arguments to be passed to `coro_func`. If you need to pass keyword arguments, then use :func:`functools.partial`. + runner_options: Options for backend's runner. Returns: - Whatever `coro_func` returns. + Whatever ``await coro_func(*args)`` returns. """ - with self.new_runner() as runner: - return runner.run(coro_func, *args) + raise NotImplementedError @abstractmethod async def coro_yield(self) -> None: @@ -847,9 +860,21 @@ async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, _T_co]) -> _T raise NotImplementedError @abstractmethod - def timeout(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def open_cancel_scope(self, *, deadline: float = ...) -> CancelScope: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Open a new cancel scope. See :meth:`move_on_after` for details. + + Parameters: + deadline: absolute time to stop waiting. Defaults to :data:`math.inf`. + + Returns: + a new cancel scope. + """ + raise NotImplementedError + + def timeout(self, delay: float) -> AbstractContextManager[CancelScope]: + """ + Returns a :term:`context manager` that can be used to limit the amount of time spent waiting on something. This function and :meth:`move_on_after` are similar in that both create a context manager with a given timeout, and if the timeout expires then both will cause ``backend.get_cancelled_exc_class()`` to be raised within the scope. @@ -860,14 +885,13 @@ def timeout(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: delay: number of seconds to wait. Returns: - an :term:`asynchronous context manager` + a :term:`context manager` """ - raise NotImplementedError + return _timeout_after(self, delay) - @abstractmethod - def timeout_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def timeout_at(self, deadline: float) -> AbstractContextManager[CancelScope]: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Returns a :term:`context manager` that can be used to limit the amount of time spent waiting on something. This function and :meth:`move_on_at` are similar in that both create a context manager with a given timeout, and if the timeout expires then both will cause ``backend.get_cancelled_exc_class()`` to be raised within the scope. @@ -878,14 +902,13 @@ def timeout_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHand deadline: absolute time to stop waiting. Returns: - an :term:`asynchronous context manager` + a :term:`context manager` """ - raise NotImplementedError + return _timeout_at(self, deadline) - @abstractmethod - def move_on_after(self, delay: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def move_on_after(self, delay: float) -> CancelScope: """ - Returns an :term:`asynchronous context manager` that can be used to limit the amount of time spent waiting on something. + Returns a new :class:`CancelScope` that can be used to limit the amount of time spent waiting on something. The deadline is set to now + `delay`. Example:: @@ -896,7 +919,7 @@ async def long_running_operation(backend): async def main(): ... - async with backend.move_on_after(10): + with backend.move_on_after(10): await long_running_operation(backend) print("After at most 10 seconds.") @@ -907,15 +930,14 @@ async def main(): Parameters: delay: number of seconds to wait. If `delay` is :data:`math.inf`, no time limit will be applied; this can be useful if the delay is unknown when the context manager is created. - In either case, the context manager can be rescheduled after creation using :meth:`TimeoutHandle.reschedule`. + In either case, the context manager can be rescheduled after creation using :meth:`CancelScope.reschedule`. Returns: - an :term:`asynchronous context manager` + a new cancel scope. """ - raise NotImplementedError + return self.open_cancel_scope(deadline=self.current_time() + delay) - @abstractmethod - def move_on_at(self, deadline: float) -> AbstractAsyncContextManager[TimeoutHandle]: + def move_on_at(self, deadline: float) -> CancelScope: """ Similar to :meth:`move_on_after`, except `deadline` is the absolute time to stop waiting, or :data:`math.inf`. @@ -928,7 +950,7 @@ async def main(): ... deadline = backend.current_time() + 10 - async with backend.move_on_at(deadline): + with backend.move_on_at(deadline): await long_running_operation(backend) print("After at most 10 seconds.") @@ -937,9 +959,9 @@ async def main(): deadline: absolute time to stop waiting. Returns: - an :term:`asynchronous context manager` + a new cancel scope. """ - raise NotImplementedError + return self.open_cancel_scope(deadline=deadline) @abstractmethod def current_time(self) -> float: @@ -988,31 +1010,6 @@ async def sleep_until(self, deadline: float) -> None: """ return await self.sleep(max(deadline - self.current_time(), 0)) - @abstractmethod - def spawn_task( - self, - coro_func: Callable[..., Coroutine[Any, Any, _T]], - /, - *args: Any, - context: contextvars.Context | None = ..., - ) -> SystemTask[_T]: - """ - Starts a new "system" task. - - It is a background task that runs concurrently with the current root task. - - Parameters: - coro_func: An async function. - args: Positional arguments to be passed to `coro_func`. If you need to pass keyword arguments, - then use :func:`functools.partial`. - context: If given, it must be a :class:`contextvars.Context` instance in which the coroutine should be executed. - If the framework does not support contexts (or does not use them), it must simply ignore this parameter. - - Returns: - the created task. - """ - raise NotImplementedError - @abstractmethod def create_task_group(self) -> TaskGroup: """ @@ -1383,3 +1380,16 @@ async def wait_future(self, future: concurrent.futures.Future[_T_co]) -> _T_co: Whatever returns ``future.result()`` """ raise NotImplementedError + + +def _timeout_after(backend: AsyncBackend, delay: float) -> contextlib._GeneratorContextManager[CancelScope]: + return _timeout_at(backend, backend.current_time() + delay) + + +@contextlib.contextmanager +def _timeout_at(backend: AsyncBackend, deadline: float) -> Iterator[CancelScope]: + with backend.move_on_at(deadline) as scope: + yield scope + + if scope.cancelled_caught(): + raise TimeoutError("timed out") diff --git a/src/easynetwork/api_async/backend/factory.py b/src/easynetwork/api_async/backend/factory.py index 5130aa08..f14a3cb9 100644 --- a/src/easynetwork/api_async/backend/factory.py +++ b/src/easynetwork/api_async/backend/factory.py @@ -151,7 +151,7 @@ def __load_backend_cls_from_entry_point(name: str) -> type[AsyncBackend]: @staticmethod @functools.cache def __get_available_backends() -> MappingProxyType[str, EntryPoint]: - from importlib.metadata import entry_points as get_all_entry_points + from importlib.metadata import EntryPoint, entry_points as get_all_entry_points entry_points = get_all_entry_points(group=AsyncBackendFactory.GROUP_NAME) duplicate_counter: Counter[str] = Counter([ep.name for ep in entry_points]) @@ -161,6 +161,11 @@ def __get_available_backends() -> MappingProxyType[str, EntryPoint]: backends: dict[str, EntryPoint] = {ep.name: ep for ep in entry_points} - assert "asyncio" in backends, "SystemError: Missing 'asyncio' entry point." # nosec assert_used + if "asyncio" not in backends: + backends["asyncio"] = EntryPoint( + name="asyncio", + value="easynetwork_asyncio:AsyncIOBackend", + group=AsyncBackendFactory.GROUP_NAME, + ) return MappingProxyType(backends) diff --git a/src/easynetwork/api_async/backend/futures.py b/src/easynetwork/api_async/backend/futures.py index 18f4dea4..d0ea58ea 100644 --- a/src/easynetwork/api_async/backend/futures.py +++ b/src/easynetwork/api_async/backend/futures.py @@ -21,7 +21,8 @@ import concurrent.futures import contextvars import functools -from collections.abc import Callable, Mapping +from collections import deque +from collections.abc import AsyncGenerator, Callable, Iterable, Mapping from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar from .factory import AsyncBackendFactory @@ -132,7 +133,44 @@ async def run(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwar func = self._setup_func(func) executor = self.__executor backend = self.__backend - return await backend.wait_future(executor.submit(func, *args, **kwargs)) + return await _result_or_cancel(backend, executor.submit(func, *args, **kwargs)) + + def map(self, func: Callable[..., _T], *iterables: Iterable[Any]) -> AsyncGenerator[_T, None]: + """ + Returns an asynchronous iterator equivalent to ``map(fn, iter)``. + + Example:: + + def pow_50(x): + return x**50 + + async with AsyncExecutor(ProcessPoolExecutor()) as executor: + results = [result async for result in executor.map(pow_50, (1, 4, 12))] + + Parameters: + func: A callable that will take as many arguments as there are passed `iterables`. + iterables: iterables yielding arguments for `func`. + + Raises: + Exception: If ``fn(*args)`` raises for any values. + + Returns: + An asynchronous iterator equivalent to ``map(func, *iterables)`` but the calls may be evaluated out-of-order. + """ + + backend = self.__backend + executor = self.__executor + fs = deque(executor.submit(self._setup_func(func), *args) for args in zip(*iterables)) + + async def result_iterator() -> AsyncGenerator[_T, None]: + try: + while fs: + yield await _result_or_cancel(backend, fs.popleft()) + finally: + for future in fs: + future.cancel() + + return result_iterator() def shutdown_nowait(self, *, cancel_futures: bool = False) -> None: """ @@ -174,3 +212,14 @@ def _setup_func(self, func: Callable[_P, _T]) -> Callable[_P, _T]: func = functools.partial(ctx.run, func) # type: ignore[assignment] return func + + +async def _result_or_cancel(backend: AsyncBackend, future: concurrent.futures.Future[_T]) -> _T: + try: + try: + return await backend.wait_future(future) + finally: + future.cancel() + finally: + # Break a reference cycle with the exception in future._exception + del future diff --git a/src/easynetwork/api_async/backend/tasks.py b/src/easynetwork/api_async/backend/tasks.py deleted file mode 100644 index 4cb513a8..00000000 --- a/src/easynetwork/api_async/backend/tasks.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -"""Task utilities module""" - -from __future__ import annotations - -__all__ = ["SingleTaskRunner"] - -import functools -from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar - -if TYPE_CHECKING: - from .abc import AsyncBackend, SystemTask - - -_P = ParamSpec("_P") -_T_co = TypeVar("_T_co", covariant=True) - - -class SingleTaskRunner(Generic[_T_co]): - """ - An helper class to execute a coroutine function only once. - - In addition to one-time execution, concurrent calls will simply wait for the result:: - - async def expensive_task(): - print("Start expensive task") - - ... - - print("Done") - return 42 - - async def main(): - ... - - task_runner = SingleTaskRunner(backend, expensive_task) - async with backend.create_task_group() as task_group: - tasks = [task_group.start_soon(task_runner.run) for _ in range(10)] - - assert all(await t.join() == 42 for t in tasks) - """ - - __slots__ = ( - "__backend", - "__coro_func", - "__task", - "__weakref__", - ) - - def __init__( - self, - backend: AsyncBackend, - coro_func: Callable[_P, Coroutine[Any, Any, _T_co]], - /, - *args: _P.args, - **kwargs: _P.kwargs, - ) -> None: - """ - Parameters: - backend: The asynchronous backend interface. - coro_func: An async function. - args: Positional arguments to be passed to `coro_func`. - kwargs: Keyword arguments to be passed to `coro_func`. - """ - super().__init__() - - self.__backend: AsyncBackend = backend - self.__coro_func: Callable[[], Coroutine[Any, Any, _T_co]] | None = functools.partial( - coro_func, - *args, - **kwargs, - ) - self.__task: SystemTask[_T_co] | None = None - - def cancel(self) -> bool: - """ - Cancel coroutine execution. - - If the runner was not used yet, :meth:`run` will not call `coro_func` and raise ``backend.get_cancelled_exc_class()``. - - If `coro_func` is already running, a cancellation request is sent to the coroutine. - - Returns: - :data:`True` in case of success, :data:`False` otherwise. - """ - self.__coro_func = None - if self.__task is not None: - return self.__task.cancel() - return True - - async def run(self) -> _T_co: - """ - Executes the coroutine `coro_func`. - - Raises: - Exception: Whatever ``coro_func`` raises. - - Returns: - Whatever ``coro_func`` returns. - """ - must_cancel_inner_task: bool = False - if self.__task is None: - must_cancel_inner_task = True - if self.__coro_func is None: - self.__task = self.__backend.spawn_task(self.__backend.sleep_forever) - self.__task.cancel() - else: - coro_func = self.__coro_func - self.__coro_func = None - self.__task = self.__backend.spawn_task(coro_func) - del coro_func - - try: - if must_cancel_inner_task: - return await self.__task.join_or_cancel() - else: - return await self.__task.join() - finally: - del self # Avoid circular reference with raised exception (if any) diff --git a/src/easynetwork/api_async/client/abc.py b/src/easynetwork/api_async/client/abc.py index 2a7b79f6..7b291f58 100644 --- a/src/easynetwork/api_async/client/abc.py +++ b/src/easynetwork/api_async/client/abc.py @@ -228,7 +228,7 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter while True: try: - async with timeout_after(timeout): + with timeout_after(timeout): _start = perf_counter() packet = await self.recv_packet() _end = perf_counter() diff --git a/src/easynetwork/api_async/client/tcp.py b/src/easynetwork/api_async/client/tcp.py index 5d3dfdf3..e7a019c0 100644 --- a/src/easynetwork/api_async/client/tcp.py +++ b/src/easynetwork/api_async/client/tcp.py @@ -19,10 +19,11 @@ __all__ = ["AsyncTCPNetworkClient"] import contextlib as _contextlib +import dataclasses as _dataclasses import errno as _errno import socket as _socket -from collections.abc import Callable, Iterator, Mapping -from typing import TYPE_CHECKING, Any, NoReturn, TypedDict, final, overload +from collections.abc import Awaitable, Callable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, NoReturn, final, overload try: import ssl as _ssl @@ -41,24 +42,41 @@ check_socket_family as _check_socket_family, check_socket_no_ssl as _check_socket_no_ssl, error_from_errno as _error_from_errno, + make_callback as _make_callback, ) from ...tools.constants import CLOSED_SOCKET_ERRNOS, MAX_STREAM_BUFSIZE, SSL_HANDSHAKE_TIMEOUT, SSL_SHUTDOWN_TIMEOUT from ...tools.socket import SocketAddress, SocketProxy, new_socket_address, set_tcp_keepalive, set_tcp_nodelay -from ..backend.abc import AsyncBackend, AsyncHalfCloseableStreamSocketAdapter, AsyncStreamSocketAdapter, ILock +from ..backend.abc import AsyncBackend, AsyncStreamSocketAdapter, CancelScope, ILock from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from .abc import AbstractAsyncNetworkClient if TYPE_CHECKING: import ssl as _typing_ssl -class _ClientInfo(TypedDict): +@_dataclasses.dataclass(kw_only=True, frozen=True, slots=True) +class _ClientInfo: proxy: SocketProxy local_address: SocketAddress remote_address: SocketAddress +@_dataclasses.dataclass(kw_only=True, slots=True) +class _SocketConnector: + lock: ILock + factory: Callable[[], Awaitable[tuple[AsyncStreamSocketAdapter, _ClientInfo]]] | None + scope: CancelScope + _result: tuple[AsyncStreamSocketAdapter, _ClientInfo] | None = _dataclasses.field(init=False, default=None) + + async def get(self) -> tuple[AsyncStreamSocketAdapter, _ClientInfo] | None: + async with self.lock: + factory, self.factory = self.factory, None + if factory is not None: + with self.scope: + self._result = await factory() + return self._result + + class AsyncTCPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPacketT]): """ An asynchronous network client interface for TCP connections. @@ -217,7 +235,7 @@ def __init__( def _value_or_default(value: float | None, default: float) -> float: return value if value is not None else default - self.__socket_connector: SingleTaskRunner[AsyncStreamSocketAdapter] | None = None + socket_factory: Callable[[], Awaitable[AsyncStreamSocketAdapter]] match __arg: case _socket.socket() as socket: _check_socket_family(socket.family) @@ -225,8 +243,7 @@ def _value_or_default(value: float | None, default: float) -> float: if ssl: if server_hostname is None: raise ValueError("You must set server_hostname when using ssl without a host") - self.__socket_connector = SingleTaskRunner( - backend, + socket_factory = _make_callback( backend.wrap_ssl_over_tcp_client_socket, socket, ssl_context=ssl, @@ -236,11 +253,10 @@ def _value_or_default(value: float | None, default: float) -> float: **kwargs, ) else: - self.__socket_connector = SingleTaskRunner(backend, backend.wrap_tcp_client_socket, socket, **kwargs) + socket_factory = _make_callback(backend.wrap_tcp_client_socket, socket, **kwargs) case (host, port): if ssl: - self.__socket_connector = SingleTaskRunner( - backend, + socket_factory = _make_callback( backend.create_ssl_over_tcp_connection, host, port, @@ -251,11 +267,16 @@ def _value_or_default(value: float | None, default: float) -> float: **kwargs, ) else: - self.__socket_connector = SingleTaskRunner(backend, backend.create_tcp_connection, host, port, **kwargs) + socket_factory = _make_callback(backend.create_tcp_connection, host, port, **kwargs) case _: # pragma: no cover raise TypeError("Invalid arguments") - assert self.__socket_connector is not None # nosec assert_used + self.__socket_connector: _SocketConnector | None = _SocketConnector( + lock=self.__backend.create_lock(), + factory=_make_callback(self.__create_socket, socket_factory), + scope=self.__backend.open_cancel_scope(), + ) + assert ssl_shared_lock is not None # nosec assert_used self.__receive_lock: ILock @@ -312,33 +333,26 @@ async def wait_connected(self) -> None: ConnectionError: could not connect to remote. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - if self.__socket is None: - socket_connector = self.__socket_connector - if socket_connector is None: - raise ClientClosedError("Client is closing, or is already closed") - socket = await socket_connector.run() - if self.__socket_connector is None: # wait_connected() or aclose() called in concurrency - return await self.__backend.cancel_shielded_coro_yield() - self.__socket = socket - self.__socket_connector = None - if self.__info is None: - self.__info = self.__build_info_dict(self.__socket) - socket_proxy = self.__info["proxy"] - with _contextlib.suppress(OSError): - set_tcp_nodelay(socket_proxy, True) - with _contextlib.suppress(OSError): - set_tcp_keepalive(socket_proxy, True) + await self.__ensure_connected() @staticmethod - def __build_info_dict(socket: AsyncStreamSocketAdapter) -> _ClientInfo: + async def __create_socket( + socket_factory: Callable[[], Awaitable[AsyncStreamSocketAdapter]], + ) -> tuple[AsyncStreamSocketAdapter, _ClientInfo]: + socket = await socket_factory() socket_proxy = SocketProxy(socket.socket()) - local_address: SocketAddress = new_socket_address(socket.get_local_address(), socket_proxy.family) - remote_address: SocketAddress = new_socket_address(socket.get_remote_address(), socket_proxy.family) - return { - "proxy": socket_proxy, - "local_address": local_address, - "remote_address": remote_address, - } + local_address: SocketAddress = new_socket_address(socket_proxy.getsockname(), socket_proxy.family) + remote_address: SocketAddress = new_socket_address(socket_proxy.getpeername(), socket_proxy.family) + info = _ClientInfo( + proxy=socket_proxy, + local_address=local_address, + remote_address=remote_address, + ) + with _contextlib.suppress(OSError): + set_tcp_nodelay(socket_proxy, True) + with _contextlib.suppress(OSError): + set_tcp_keepalive(socket_proxy, True) + return socket, info def is_closing(self) -> bool: """ @@ -372,7 +386,7 @@ async def aclose(self) -> None: Can be safely called multiple times. """ if self.__socket_connector is not None: - self.__socket_connector.cancel() + self.__socket_connector.scope.cancel() self.__socket_connector = None async with self.__send_lock: socket, self.__socket = self.__socket, None @@ -389,8 +403,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: """ Sends `packet` to the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Warning: In the case of a cancellation, it is impossible to know if all the packet data has been sent. This would leave the connection in an inconsistent state. @@ -406,7 +418,7 @@ async def send_packet(self, packet: _SentPacketT) -> None: RuntimeError: :meth:`send_eof` has been called earlier. """ async with self.__send_lock: - socket = await self.__ensure_connected(check_socket_is_closing=True) + socket = await self.__ensure_connected() if self.__eof_sent: raise RuntimeError("send_eof() has been called earlier") with self.__convert_socket_error(): @@ -423,26 +435,20 @@ async def send_eof(self) -> None: ClientClosedError: the client object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - try: - socket = await self.__ensure_connected(check_socket_is_closing=False) - except ConnectionError: - return - if not isinstance(socket, AsyncHalfCloseableStreamSocketAdapter): - raise NotImplementedError - async with self.__send_lock: if self.__eof_sent: return + try: + socket = await self.__ensure_connected() + except ConnectionError: + return + await socket.send_eof() self.__eof_sent = True - if not socket.is_closing(): - await socket.send_eof() async def recv_packet(self) -> _ReceivedPacketT: """ Waits for a new packet to arrive from the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Raises: ClientClosedError: the client object is closed. ConnectionError: connection unexpectedly closed during operation. @@ -460,7 +466,7 @@ async def recv_packet(self) -> _ReceivedPacketT: except StopIteration: pass - socket = await self.__ensure_connected(check_socket_is_closing=True) + socket = await self.__ensure_connected() if self.__eof_reached: self.__abort(None) @@ -496,7 +502,7 @@ def get_local_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["local_address"] + return self.__info.local_address def get_remote_address(self) -> SocketAddress: """ @@ -513,20 +519,26 @@ def get_remote_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["remote_address"] + return self.__info.remote_address def get_backend(self) -> AsyncBackend: return self.__backend get_backend.__doc__ = AbstractAsyncNetworkClient.get_backend.__doc__ - async def __ensure_connected(self, *, check_socket_is_closing: bool) -> AsyncStreamSocketAdapter: - await self.wait_connected() - assert self.__socket is not None # nosec assert_used - socket = self.__socket - if check_socket_is_closing and socket.is_closing(): + async def __ensure_connected(self) -> AsyncStreamSocketAdapter: + if self.__socket is None or self.__info is None: + socket_and_info = None + if (socket_connector := self.__socket_connector) is not None: + socket_and_info = await socket_connector.get() + self.__socket_connector = None + if socket_and_info is None: + raise ClientClosedError("Client is closing, or is already closed") + self.__socket, self.__info = socket_and_info + + if self.__socket.is_closing(): self.__abort(None) - return socket + return self.__socket @_contextlib.contextmanager def __convert_socket_error(self) -> Iterator[None]: @@ -554,7 +566,7 @@ def socket(self) -> SocketProxy: """ if self.__info is None: raise AttributeError("Socket not connected") - return self.__info["proxy"] + return self.__info.proxy @property @final diff --git a/src/easynetwork/api_async/client/udp.py b/src/easynetwork/api_async/client/udp.py index ed484728..2c15644b 100644 --- a/src/easynetwork/api_async/client/udp.py +++ b/src/easynetwork/api_async/client/udp.py @@ -19,12 +19,13 @@ __all__ = ["AsyncUDPNetworkClient", "AsyncUDPNetworkEndpoint"] import contextlib +import dataclasses as _dataclasses import errno as _errno import math import socket as _socket import time -from collections.abc import AsyncGenerator, AsyncIterator, Mapping -from typing import TYPE_CHECKING, Any, Generic, Self, TypedDict, final, overload +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping +from typing import TYPE_CHECKING, Any, Generic, Self, final, overload from ..._typevars import _ReceivedPacketT, _SentPacketT from ...exceptions import ClientClosedError, DatagramProtocolParseError @@ -35,24 +36,41 @@ check_socket_no_ssl as _check_socket_no_ssl, ensure_datagram_socket_bound as _ensure_datagram_socket_bound, error_from_errno as _error_from_errno, + make_callback as _make_callback, ) from ...tools.constants import MAX_DATAGRAM_BUFSIZE from ...tools.socket import SocketAddress, SocketProxy, new_socket_address -from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, ILock +from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, CancelScope, ILock from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from .abc import AbstractAsyncNetworkClient if TYPE_CHECKING: from types import TracebackType -class _EndpointInfo(TypedDict): +@_dataclasses.dataclass(kw_only=True, frozen=True, slots=True) +class _EndpointInfo: proxy: SocketProxy local_address: SocketAddress remote_address: SocketAddress | None +@_dataclasses.dataclass(kw_only=True, slots=True) +class _SocketConnector: + lock: ILock + factory: Callable[[], Awaitable[tuple[AsyncDatagramSocketAdapter, _EndpointInfo]]] | None + scope: CancelScope + _result: tuple[AsyncDatagramSocketAdapter, _EndpointInfo] | None = _dataclasses.field(init=False, default=None) + + async def get(self) -> tuple[AsyncDatagramSocketAdapter, _EndpointInfo] | None: + async with self.lock: + factory, self.factory = self.factory, None + if factory is not None: + with self.scope: + self._result = await factory() + return self._result + + class AsyncUDPNetworkEndpoint(Generic[_SentPacketT, _ReceivedPacketT]): """Asynchronous generic UDP endpoint interface. @@ -64,7 +82,7 @@ class AsyncUDPNetworkEndpoint(Generic[_SentPacketT, _ReceivedPacketT]): __slots__ = ( "__socket", "__backend", - "__socket_builder", + "__socket_connector", "__info", "__receive_lock", "__send_lock", @@ -136,18 +154,23 @@ def __init__( self.__backend: AsyncBackend = backend self.__info: _EndpointInfo | None = None - self.__socket_builder: SingleTaskRunner[AsyncDatagramSocketAdapter] | None = None + socket_factory: Callable[[], Awaitable[AsyncDatagramSocketAdapter]] match kwargs: case {"socket": _socket.socket() as socket, **kwargs}: _check_socket_family(socket.family) _check_socket_no_ssl(socket) _ensure_datagram_socket_bound(socket) - self.__socket_builder = SingleTaskRunner(backend, backend.wrap_udp_socket, socket, **kwargs) + socket_factory = _make_callback(backend.wrap_udp_socket, socket, **kwargs) case _: if kwargs.get("local_address") is None: kwargs["local_address"] = ("localhost", 0) - self.__socket_builder = SingleTaskRunner(backend, backend.create_udp_endpoint, **kwargs) + socket_factory = _make_callback(backend.create_udp_endpoint, **kwargs) + self.__socket_connector: _SocketConnector | None = _SocketConnector( + lock=self.__backend.create_lock(), + factory=_make_callback(self.__create_socket, socket_factory), + scope=self.__backend.open_cancel_scope(), + ) self.__receive_lock: ILock = backend.create_lock() self.__send_lock: ILock = backend.create_lock() @@ -211,34 +234,28 @@ async def wait_bound(self) -> None: ClientClosedError: the client object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - if self.__socket is None: - socket_builder = self.__socket_builder - if socket_builder is None: - raise ClientClosedError("Client is closing, or is already closed") - socket = await socket_builder.run() - if self.__socket_builder is None: # wait_bound() or aclose() called in concurrency - return await self.__backend.cancel_shielded_coro_yield() - self.__socket = socket - self.__socket_builder = None - if self.__info is None: - self.__info = self.__build_info_dict(self.__socket) + await self.__ensure_opened() @staticmethod - def __build_info_dict(socket: AsyncDatagramSocketAdapter) -> _EndpointInfo: + async def __create_socket( + socket_factory: Callable[[], Awaitable[AsyncDatagramSocketAdapter]], + ) -> tuple[AsyncDatagramSocketAdapter, _EndpointInfo]: + socket = await socket_factory() socket_proxy = SocketProxy(socket.socket()) - local_address: SocketAddress = new_socket_address(socket.get_local_address(), socket_proxy.family) + local_address: SocketAddress = new_socket_address(socket_proxy.getsockname(), socket_proxy.family) if local_address.port == 0: raise AssertionError(f"{socket} is not bound to a local address") remote_address: SocketAddress | None - if (peername := socket.get_remote_address()) is None: + try: + remote_address = new_socket_address(socket_proxy.getpeername(), socket_proxy.family) + except OSError: remote_address = None - else: - remote_address = new_socket_address(peername, socket_proxy.family) - return { - "proxy": socket_proxy, - "local_address": local_address, - "remote_address": remote_address, - } + info = _EndpointInfo( + proxy=socket_proxy, + local_address=local_address, + remote_address=remote_address, + ) + return socket, info def is_closing(self) -> bool: """ @@ -252,7 +269,7 @@ def is_closing(self) -> bool: Returns: the endpoint state. """ - if self.__socket_builder is not None: + if self.__socket_connector is not None: return False socket = self.__socket return socket is None or socket.is_closing() @@ -270,9 +287,9 @@ async def aclose(self) -> None: Can be safely called multiple times. """ - if self.__socket_builder is not None: - self.__socket_builder.cancel() - self.__socket_builder = None + if self.__socket_connector is not None: + self.__socket_connector.scope.cancel() + self.__socket_connector = None async with self.__send_lock: socket, self.__socket = self.__socket, None if socket is None: @@ -292,8 +309,6 @@ async def send_packet_to( """ Sends `packet` to the remote endpoint `address`. Does not require task synchronization. - Calls :meth:`wait_bound`. - If a remote address is configured, `address` must be :data:`None` or the same as the remote address, otherwise `address` must not be :data:`None`. @@ -311,7 +326,7 @@ async def send_packet_to( async with self.__send_lock: socket = await self.__ensure_opened() assert self.__info is not None # nosec assert_used - if (remote_addr := self.__info["remote_address"]) is not None: + if (remote_addr := self.__info.remote_address) is not None: if address is not None: if new_socket_address(address, self.socket.family) != remote_addr: raise ValueError(f"Invalid address: must be None or {remote_addr}") @@ -326,8 +341,6 @@ async def recv_packet_from(self) -> tuple[_ReceivedPacketT, SocketAddress]: """ Waits for a new packet to arrive from another endpoint. Does not require task synchronization. - Calls :meth:`wait_bound`. - Raises: ClientClosedError: the endpoint object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. @@ -386,7 +399,7 @@ async def iter_received_packets_from( while True: try: - async with timeout_after(timeout): + with timeout_after(timeout): _start = perf_counter() packet_tuple = await self.recv_packet_from() _end = perf_counter() @@ -409,7 +422,7 @@ def get_local_address(self) -> SocketAddress: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["local_address"] + return self.__info.local_address def get_remote_address(self) -> SocketAddress | None: """ @@ -424,14 +437,20 @@ def get_remote_address(self) -> SocketAddress | None: """ if self.__info is None: raise _error_from_errno(_errno.ENOTSOCK) - return self.__info["remote_address"] + return self.__info.remote_address def get_backend(self) -> AsyncBackend: return self.__backend async def __ensure_opened(self) -> AsyncDatagramSocketAdapter: - await self.wait_bound() - assert self.__socket is not None # nosec assert_used + if self.__socket is None or self.__info is None: + socket_and_info = None + if (socket_connector := self.__socket_connector) is not None: + socket_and_info = await socket_connector.get() + self.__socket_connector = None + if socket_and_info is None: + raise ClientClosedError("Client is closing, or is already closed") + self.__socket, self.__info = socket_and_info if self.__socket.is_closing(): raise _error_from_errno(_errno.ECONNABORTED) return self.__socket @@ -445,7 +464,7 @@ def socket(self) -> SocketProxy: """ if self.__info is None: raise AttributeError("Socket not connected") - return self.__info["proxy"] + return self.__info.proxy class AsyncUDPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPacketT], Generic[_SentPacketT, _ReceivedPacketT]): @@ -597,8 +616,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: """ Sends `packet` to the remote endpoint. Does not require task synchronization. - Calls :meth:`wait_connected`. - Warning: In the case of a cancellation, it is impossible to know if all the packet data has been sent. @@ -609,7 +626,6 @@ async def send_packet(self, packet: _SentPacketT) -> None: ClientClosedError: the client object is closed. OSError: unrelated OS error occurred. You should check :attr:`OSError.errno`. """ - await self.wait_connected() return await self.__endpoint.send_packet_to(packet, None) async def recv_packet(self) -> _ReceivedPacketT: @@ -626,12 +642,10 @@ async def recv_packet(self) -> _ReceivedPacketT: Returns: the received packet. """ - await self.wait_connected() packet, _ = await self.__endpoint.recv_packet_from() return packet - async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIterator[_ReceivedPacketT]: - await self.wait_connected() + async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncGenerator[_ReceivedPacketT, None]: async with contextlib.aclosing(self.__endpoint.iter_received_packets_from(timeout=timeout)) as generator: async for packet, _ in generator: yield packet diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index ef2341c8..a77acf87 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -53,9 +53,7 @@ set_tcp_keepalive, set_tcp_nodelay, ) -from ..backend.abc import AsyncHalfCloseableStreamSocketAdapter from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from ._tools.actions import ErrorAction as _ErrorAction, RequestAction as _RequestAction from .abc import AbstractAsyncNetworkServer, SupportsEventSet from .handler import AsyncStreamClient, AsyncStreamRequestHandler @@ -68,6 +66,7 @@ AsyncBackend, AsyncListenerSocketAdapter, AsyncStreamSocketAdapter, + CancelScope, IEvent, Task, TaskGroup, @@ -83,7 +82,7 @@ class AsyncTCPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp "__backend", "__listeners", "__listeners_factory", - "__listeners_factory_runner", + "__listeners_factory_scope", "__protocol", "__request_handler", "__is_shutdown", @@ -201,7 +200,7 @@ def _value_or_default(value: float | None, default: float) -> float: backlog=backlog, reuse_port=reuse_port, ) - self.__listeners_factory_runner: SingleTaskRunner[Sequence[AsyncListenerSocketAdapter]] | None = None + self.__listeners_factory_scope: CancelScope | None = None self.__backend: AsyncBackend = backend self.__listeners: tuple[AsyncListenerSocketAdapter, ...] | None = None @@ -240,7 +239,8 @@ def stop_listening(self) -> None: del listener_task async def server_close(self) -> None: - self.__kill_listener_factory_runner() + if self.__listeners_factory_scope is not None: + self.__listeners_factory_scope.cancel() self.__listeners_factory = None await self.__close_listeners() @@ -267,7 +267,6 @@ async def close_listener(listener: AsyncListenerSocketAdapter) -> None: await self.__backend.cancel_shielded_coro_yield() async def shutdown(self) -> None: - self.__kill_listener_factory_runner() if self.__mainloop_task is not None: self.__mainloop_task.cancel() if self.__shutdown_asked: @@ -281,10 +280,6 @@ async def shutdown(self) -> None: shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__ - def __kill_listener_factory_runner(self) -> None: - if self.__listeners_factory_runner is not None: - self.__listeners_factory_runner.cancel() - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: async with _contextlib.AsyncExitStack() as server_exit_stack: is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack()) @@ -301,14 +296,17 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Bind and activate assert self.__listeners is None # nosec assert_used - assert self.__listeners_factory_runner is None # nosec assert_used + assert self.__listeners_factory_scope is None # nosec assert_used if self.__listeners_factory is None: raise ServerClosedError("Closed server") try: - self.__listeners_factory_runner = SingleTaskRunner(self.__backend, self.__listeners_factory) - self.__listeners = tuple(await self.__listeners_factory_runner.run()) + with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: + await self.__backend.coro_yield() + self.__listeners = tuple(await self.__listeners_factory()) + if self.__listeners_factory_scope.cancelled_caught(): + raise ServerClosedError("Closed server") finally: - self.__listeners_factory_runner = None + self.__listeners_factory_scope = None if not self.__listeners: self.__listeners = None raise OSError("empty listeners list") @@ -471,8 +469,8 @@ async def __new_request_handler(self, client: _ConnectedClientAPI[_ResponseT]) - return request_handler_generator async def __force_close_stream_socket(self, socket: AsyncStreamSocketAdapter) -> None: - with _contextlib.suppress(OSError): - await self.__backend.ignore_cancellation(socket.aclose()) + with self.__backend.move_on_after(0), _contextlib.suppress(OSError): + await socket.aclose() @classmethod def __have_errno(cls, exc: OSError | BaseExceptionGroup[OSError], errnos: set[int]) -> bool: @@ -536,7 +534,7 @@ def get_addresses(self) -> Sequence[SocketAddress]: if (listeners := self.__listeners) is None: return () return tuple( - new_socket_address(listener.get_local_address(), listener.socket().family) + new_socket_address(listener.socket().getsockname(), listener.socket().family) for listener in listeners if not listener.is_closing() ) @@ -626,7 +624,7 @@ def __init__( producer: StreamDataProducer[_ResponseT], logger: logging.Logger, ) -> None: - super().__init__(new_socket_address(socket.get_remote_address(), socket.socket().family)) + super().__init__(new_socket_address(socket.socket().getpeername(), socket.socket().family)) self.__socket: AsyncStreamSocketAdapter = socket self.__closed: bool = False @@ -641,18 +639,14 @@ def is_closing(self) -> bool: async def _force_close(self) -> None: self.__closed = True async with self.__send_lock: # If self.aclose() took the lock, wait for it to finish - socket = self.__socket - await self.__shutdown_socket(socket) + pass async def aclose(self) -> None: async with self.__send_lock: socket = self.__socket self.__closed = True - try: - await self.__shutdown_socket(socket) - finally: - with _contextlib.suppress(OSError): - await socket.aclose() + with _contextlib.suppress(OSError): + await socket.aclose() async def send_packet(self, packet: _ResponseT, /) -> None: self.__check_closed() @@ -675,14 +669,6 @@ def __check_closed(self) -> AsyncStreamSocketAdapter: raise ClientClosedError("Closed client") return socket - @staticmethod - async def __shutdown_socket(socket: AsyncStreamSocketAdapter) -> None: - if not isinstance(socket, AsyncHalfCloseableStreamSocketAdapter): - return - with _contextlib.suppress(OSError): - if not socket.is_closing(): - await socket.send_eof() - @property def socket(self) -> SocketProxy: return self.__proxy diff --git a/src/easynetwork/api_async/server/udp.py b/src/easynetwork/api_async/server/udp.py index 2712e961..2fa4803a 100644 --- a/src/easynetwork/api_async/server/udp.py +++ b/src/easynetwork/api_async/server/udp.py @@ -33,19 +33,19 @@ from ...protocol import DatagramProtocol from ...tools._utils import ( check_real_socket_state as _check_real_socket_state, + exception_with_notes as _exception_with_notes, make_callback as _make_callback, remove_traceback_frames_in_place as _remove_traceback_frames_in_place, ) from ...tools.constants import MAX_DATAGRAM_BUFSIZE from ...tools.socket import SocketAddress, SocketProxy, new_socket_address from ..backend.factory import AsyncBackendFactory -from ..backend.tasks import SingleTaskRunner from ._tools.actions import ErrorAction as _ErrorAction, RequestAction as _RequestAction from .abc import AbstractAsyncNetworkServer, SupportsEventSet from .handler import AsyncDatagramClient, AsyncDatagramRequestHandler if TYPE_CHECKING: - from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, ICondition, IEvent, ILock, Task, TaskGroup + from ..backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, CancelScope, ICondition, IEvent, ILock, Task, TaskGroup _KT = TypeVar("_KT") _VT = TypeVar("_VT") @@ -60,7 +60,7 @@ class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp "__backend", "__socket", "__socket_factory", - "__socket_factory_runner", + "__socket_factory_scope", "__protocol", "__request_handler", "__is_shutdown", @@ -121,7 +121,7 @@ def __init__( remote_address=None, reuse_port=reuse_port, ) - self.__socket_factory_runner: SingleTaskRunner[AsyncDatagramSocketAdapter] | None = None + self.__socket_factory_scope: CancelScope | None = None self.__backend: AsyncBackend = backend self.__socket: AsyncDatagramSocketAdapter | None = None @@ -148,7 +148,8 @@ def is_serving(self) -> bool: is_serving.__doc__ = AbstractAsyncNetworkServer.is_serving.__doc__ async def server_close(self) -> None: - self.__kill_socket_factory_runner() + if self.__socket_factory_scope is not None: + self.__socket_factory_scope.cancel() self.__socket_factory = None await self.__close_socket() @@ -170,7 +171,6 @@ async def close_socket(socket: AsyncDatagramSocketAdapter) -> None: exit_stack.push_async_callback(self.__mainloop_task.wait) async def shutdown(self) -> None: - self.__kill_socket_factory_runner() if self.__mainloop_task is not None: self.__mainloop_task.cancel() if self.__shutdown_asked: @@ -184,10 +184,6 @@ async def shutdown(self) -> None: shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__ - def __kill_socket_factory_runner(self) -> None: - if self.__socket_factory_runner is not None: - self.__socket_factory_runner.cancel() - async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: async with _contextlib.AsyncExitStack() as server_exit_stack: is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack()) @@ -204,14 +200,17 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Bind and activate assert self.__socket is None # nosec assert_used - assert self.__socket_factory_runner is None # nosec assert_used + assert self.__socket_factory_scope is None # nosec assert_used if self.__socket_factory is None: raise ServerClosedError("Closed server") try: - self.__socket_factory_runner = SingleTaskRunner(self.__backend, self.__socket_factory) - self.__socket = await self.__socket_factory_runner.run() + with self.__backend.open_cancel_scope() as self.__socket_factory_scope: + await self.__backend.coro_yield() + self.__socket = await self.__socket_factory() + if self.__socket_factory_scope.cancelled_caught() or self.__socket is None: + raise ServerClosedError("Closed server") finally: - self.__socket_factory_runner = None + self.__socket_factory_scope = None ################### # Final teardown @@ -362,11 +361,8 @@ async def __new_request_handler(self, client: _ClientAPI[_ResponseT]) -> AsyncGe def __check_datagram_queue_not_empty(datagram_queue: deque[bytes]) -> None: if len(datagram_queue) == 0: # pragma: no cover msg = "The server has created too many tasks and ends up in an inconsistent state." - try: - raise RuntimeError(msg) - except RuntimeError as exc: - exc.add_note("Please fill an issue (https://github.com/francis-clairicia/EasyNetwork/issues)") - raise + note = "Please fill an issue (https://github.com/francis-clairicia/EasyNetwork/issues)" + raise _exception_with_notes(RuntimeError(msg), note) @_contextlib.contextmanager def __suppress_and_log_remaining_exception(self, client_address: SocketAddress) -> Iterator[None]: @@ -425,7 +421,7 @@ def get_address(self) -> SocketAddress | None: """ if (socket := self.__socket) is None or socket.is_closing(): return None - return new_socket_address(socket.get_local_address(), socket.socket().family) + return new_socket_address(socket.socket().getsockname(), socket.socket().family) def get_backend(self) -> AsyncBackend: return self.__backend diff --git a/src/easynetwork/api_sync/server/_base.py b/src/easynetwork/api_sync/server/_base.py index a34e1618..49f8809f 100644 --- a/src/easynetwork/api_sync/server/_base.py +++ b/src/easynetwork/api_sync/server/_base.py @@ -18,11 +18,12 @@ __all__ = ["BaseStandaloneNetworkServerImpl"] +import concurrent.futures import contextlib as _contextlib import threading as _threading import time -from collections.abc import Callable, Coroutine, Iterator -from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, final +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from ...api_async.backend.abc import ThreadsPortal from ...api_async.server.abc import SupportsEventSet @@ -31,39 +32,29 @@ from .abc import AbstractNetworkServer if TYPE_CHECKING: - from ...api_async.backend.abc import AsyncBackend, Runner from ...api_async.server.abc import AbstractAsyncNetworkServer -_P = ParamSpec("_P") -_T = TypeVar("_T") - - class BaseStandaloneNetworkServerImpl(AbstractNetworkServer): __slots__ = ( "__server", - "__runner", "__close_lock", "__bootstrap_lock", "__threads_portal", "__is_shutdown", + "__is_closed", ) def __init__(self, server: AbstractAsyncNetworkServer) -> None: super().__init__() self.__server: AbstractAsyncNetworkServer = server - self.__threads_portal: _ServerThreadsPortal | None = None + self.__threads_portal: ThreadsPortal | None = None self.__is_shutdown = _threading.Event() self.__is_shutdown.set() - self.__runner: Runner | None = self.__server.get_backend().new_runner() + self.__is_closed = _threading.Event() self.__close_lock = ForkSafeLock() self.__bootstrap_lock = ForkSafeLock() - def __enter__(self) -> Self: - assert self.__runner is not None, "Server is entered twice" # nosec assert_used - self.__runner.__enter__() - return super().__enter__() - def is_serving(self) -> bool: if (portal := self._portal) is not None: with _contextlib.suppress(RuntimeError): @@ -75,23 +66,19 @@ def is_serving(self) -> bool: def server_close(self) -> None: with self.__close_lock.get(), _contextlib.ExitStack() as stack, _contextlib.suppress(RuntimeError): if (portal := self._portal) is not None: - CancelledError = self.__server.get_backend().get_cancelled_exc_class() - with _contextlib.suppress(CancelledError): + with _contextlib.suppress(concurrent.futures.CancelledError): portal.run_coroutine(self.__server.server_close) else: - runner, self.__runner = self.__runner, None - if runner is None: - return - stack.push(runner) + stack.callback(self.__is_closed.set) self.__is_shutdown.wait() # Ensure we are not in the interval between the server shutdown and the scheduler shutdown - runner.run(self.__server.server_close) + backend = self.__server.get_backend() + backend.bootstrap(self.__server.server_close) server_close.__doc__ = AbstractNetworkServer.server_close.__doc__ def shutdown(self, timeout: float | None = None) -> None: if (portal := self._portal) is not None: - CancelledError = self.__server.get_backend().get_cancelled_exc_class() - with _contextlib.suppress(RuntimeError, CancelledError): + try: # If shutdown() have been cancelled, that means the scheduler itself is shutting down, and this is what we want if timeout is None: portal.run_coroutine(self.__server.shutdown) @@ -101,16 +88,35 @@ def shutdown(self, timeout: float | None = None) -> None: portal.run_coroutine(self.__do_shutdown_with_timeout, timeout) finally: timeout -= time.perf_counter() - _start + except (RuntimeError, concurrent.futures.CancelledError): + pass self.__is_shutdown.wait(timeout) shutdown.__doc__ = AbstractNetworkServer.shutdown.__doc__ async def __do_shutdown_with_timeout(self, timeout_delay: float) -> None: backend = self.__server.get_backend() - async with backend.move_on_after(timeout_delay): + with backend.move_on_after(timeout_delay): await self.__server.shutdown() - def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: + def serve_forever( + self, + *, + is_up_event: SupportsEventSet | None = None, + runner_options: Mapping[str, Any] | None = None, + ) -> None: + """ + Starts the server's main loop. + + Parameters: + is_up_event: If given, will be triggered when the server is ready to accept new clients. + runner_options: Options to pass to the :meth:`~AsyncBackend.bootstrap` method. + + Raises: + ServerClosedError: The server is closed. + ServerAlreadyRunning: Another task already called :meth:`serve_forever`. + """ + backend = self.__server.get_backend() with _contextlib.ExitStack() as server_exit_stack, _contextlib.suppress(backend.get_cancelled_exc_class()): if is_up_event is not None: @@ -123,8 +129,7 @@ def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: locks_stack.enter_context(self.__close_lock.get()) locks_stack.enter_context(self.__bootstrap_lock.get()) - runner = self.__runner - if runner is None: + if self.__is_closed.is_set(): raise ServerClosedError("Closed server") if not self.__is_shutdown.is_set(): @@ -133,25 +138,23 @@ def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None: self.__is_shutdown.clear() server_exit_stack.callback(self.__is_shutdown.set) - async def serve_forever(runner: Runner) -> None: - try: - self.__threads_portal = _ServerThreadsPortal(backend, runner) - server_exit_stack.callback(self.__threads_portal._wait_for_all_requests) + async def serve_forever() -> None: + def reset_threads_portal() -> None: + self.__threads_portal = None + + def acquire_bootstrap_lock() -> None: + locks_stack.enter_context(self.__bootstrap_lock.get()) + server_exit_stack.callback(reset_threads_portal) + server_exit_stack.callback(acquire_bootstrap_lock) + + async with backend.create_threads_portal() as self.__threads_portal: # Initialization finished; release the locks locks_stack.close() await self.__server.serve_forever(is_up_event=is_up_event) - finally: - self.__threads_portal = None - - try: - runner.run(serve_forever, runner) - finally: - # Acquire the bootstrap lock at teardown, before calling is_shutdown.set(). - locks_stack.enter_context(self.__bootstrap_lock.get()) - serve_forever.__doc__ = AbstractNetworkServer.serve_forever.__doc__ + backend.bootstrap(serve_forever, runner_options=runner_options) @property def _server(self) -> AbstractAsyncNetworkServer: @@ -161,39 +164,3 @@ def _server(self) -> AbstractAsyncNetworkServer: def _portal(self) -> ThreadsPortal | None: with self.__bootstrap_lock.get(): return self.__threads_portal - - -@final -class _ServerThreadsPortal(ThreadsPortal): - __slots__ = ("__backend", "__runner", "__portal", "__request_count", "__request_count_lock") - - def __init__(self, backend: AsyncBackend, runner: Runner) -> None: - super().__init__() - self.__backend: AsyncBackend = backend - self.__runner: Runner = runner - self.__portal: ThreadsPortal = backend.create_threads_portal() - self.__request_count: int = 0 - self.__request_count_lock = ForkSafeLock() - - def run_coroutine(self, coro_func: Callable[_P, Coroutine[Any, Any, _T]], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: - with self.__request_context(): - return self.__portal.run_coroutine(coro_func, *args, **kwargs) - - def run_sync(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: - with self.__request_context(): - return self.__portal.run_sync(func, *args, **kwargs) - - def _wait_for_all_requests(self) -> None: - while self.__request_count > 0: - self.__runner.run(self.__backend.coro_yield) - - @_contextlib.contextmanager - def __request_context(self) -> Iterator[None]: - request_count_lock = self.__request_count_lock - with request_count_lock.get(): - self.__request_count += 1 - try: - yield - finally: - with request_count_lock.get(): - self.__request_count -= 1 diff --git a/src/easynetwork/tools/_utils.py b/src/easynetwork/tools/_utils.py index 3962ee02..db6318d5 100644 --- a/src/easynetwork/tools/_utils.py +++ b/src/easynetwork/tools/_utils.py @@ -20,6 +20,7 @@ "check_socket_no_ssl", "ensure_datagram_socket_bound", "error_from_errno", + "exception_with_notes", "is_ssl_eof_error", "is_ssl_socket", "lock_with_timeout", @@ -29,12 +30,10 @@ "retry_socket_method", "retry_ssl_socket_method", "set_reuseport", - "transform_future_exception", "validate_timeout_delay", "wait_socket_available", ] -import concurrent.futures import contextlib import errno as _errno import functools @@ -43,7 +42,7 @@ import socket as _socket import threading import time -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterable, Iterator from math import isinf, isnan from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeGuard, TypeVar, assert_never @@ -294,17 +293,11 @@ def set_reuseport(sock: SupportsSocketOptions) -> None: raise ValueError("reuse_port not supported by socket module, SO_REUSEPORT defined but not implemented.") from None -def transform_future_exception(exc: BaseException) -> BaseException: - match exc: - case SystemExit() | KeyboardInterrupt(): - cancel_exc = concurrent.futures.CancelledError().with_traceback(exc.__traceback__) - try: - cancel_exc.__cause__ = cancel_exc.__context__ = exc - exc = cancel_exc - finally: - del cancel_exc - case _: - pass +def exception_with_notes(exc: _ExcType, notes: str | Iterable[str]) -> _ExcType: + if isinstance(notes, str): + notes = (notes,) + for note in notes: + exc.add_note(note) return exc diff --git a/src/easynetwork_asyncio/__init__.py b/src/easynetwork_asyncio/__init__.py index 361cc7bc..cbd4aa32 100644 --- a/src/easynetwork_asyncio/__init__.py +++ b/src/easynetwork_asyncio/__init__.py @@ -17,8 +17,6 @@ from __future__ import annotations -__all__ = ["AsyncioBackend"] # type: list[str] +__all__ = ["AsyncIOBackend"] # type: list[str] -__version__ = "1.0.0" - -from .backend import AsyncioBackend +from .backend import AsyncIOBackend diff --git a/src/easynetwork_asyncio/backend.py b/src/easynetwork_asyncio/backend.py index df5751f0..9409662e 100644 --- a/src/easynetwork_asyncio/backend.py +++ b/src/easynetwork_asyncio/backend.py @@ -17,18 +17,19 @@ from __future__ import annotations -__all__ = ["AsyncioBackend"] +__all__ = ["AsyncIOBackend"] import asyncio import asyncio.base_events import contextvars import functools import itertools +import math import os import socket as _socket import sys -from collections.abc import Callable, Coroutine, Sequence -from contextlib import AbstractAsyncContextManager as AsyncContextManager +from collections.abc import Callable, Coroutine, Mapping, Sequence +from contextlib import closing from typing import TYPE_CHECKING, Any, NoReturn, ParamSpec, TypeVar try: @@ -45,10 +46,9 @@ from ._utils import create_connection, create_datagram_socket, ensure_resolved, open_listener_sockets_from_getaddrinfo_result from .datagram.endpoint import create_datagram_endpoint from .datagram.socket import AsyncioTransportDatagramSocketAdapter, RawDatagramSocketAdapter -from .runner import AsyncioRunner from .stream.listener import AcceptedSocket, AcceptedSSLSocket, ListenerSocketAdapter from .stream.socket import AsyncioTransportStreamSocketAdapter, RawStreamSocketAdapter -from .tasks import SystemTask, TaskGroup, TimeoutHandle, move_on_after, move_on_at, timeout, timeout_at +from .tasks import CancelScope, TaskGroup, TaskUtils from .threads import ThreadsPortal if TYPE_CHECKING: @@ -62,34 +62,27 @@ _T_co = TypeVar("_T_co", covariant=True) -class AsyncioBackend(AbstractAsyncBackend): - __slots__ = ("__use_asyncio_transport", "__asyncio_runner_factory") +class AsyncIOBackend(AbstractAsyncBackend): + __slots__ = ("__use_asyncio_transport",) - def __init__(self, *, transport: bool = True, runner_factory: Callable[[], asyncio.Runner] | None = None) -> None: + def __init__(self, *, transport: bool = True) -> None: self.__use_asyncio_transport: bool = bool(transport) - self.__asyncio_runner_factory: Callable[[], asyncio.Runner] = runner_factory or asyncio.Runner - def new_runner(self) -> AsyncioRunner: - return AsyncioRunner(self.__asyncio_runner_factory()) - - @staticmethod - def _current_asyncio_task() -> asyncio.Task[Any]: - t: asyncio.Task[Any] | None = asyncio.current_task() - if t is None: # pragma: no cover - raise RuntimeError("This function should be called within a task.") - return t + def bootstrap( + self, + coro_func: Callable[..., Coroutine[Any, Any, _T]], + *args: Any, + runner_options: Mapping[str, Any] | None = None, + ) -> _T: + # Avoid ResourceWarning by always closing the coroutine + with asyncio.Runner(**(runner_options or {})) as runner, closing(coro_func(*args)) as coro: + return runner.run(coro) async def coro_yield(self) -> None: await asyncio.sleep(0) async def cancel_shielded_coro_yield(self) -> None: - current_task: asyncio.Task[Any] = self._current_asyncio_task() - try: - await asyncio.sleep(0) - except asyncio.CancelledError as exc: - TimeoutHandle._reschedule_delayed_task_cancel(current_task, self._get_cancelled_error_message(exc)) - finally: - del current_task + await TaskUtils.cancel_shielded_coro_yield() def get_cancelled_exc_class(self) -> type[BaseException]: return asyncio.CancelledError @@ -97,70 +90,10 @@ def get_cancelled_exc_class(self) -> type[BaseException]: async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, _T_co]) -> _T_co: if not asyncio.iscoroutine(coroutine): raise TypeError("Expected a coroutine object") - task: asyncio.Task[_T_co] = asyncio.create_task(coroutine) - - # This task must be unregistered in order not to be cancelled by runner at event loop shutdown - asyncio._unregister_task(task) - - try: - await self._cancel_shielded_wait_asyncio_future(task, None) - assert task.done() # nosec assert_used - return task.result() - finally: - del task - - def timeout(self, delay: float) -> AsyncContextManager[TimeoutHandle]: - return timeout(delay) - - def timeout_at(self, deadline: float) -> AsyncContextManager[TimeoutHandle]: - return timeout_at(deadline) + return await TaskUtils.cancel_shielded_await_task(asyncio.create_task(coroutine)) - def move_on_after(self, delay: float) -> AsyncContextManager[TimeoutHandle]: - return move_on_after(delay) - - def move_on_at(self, deadline: float) -> AsyncContextManager[TimeoutHandle]: - return move_on_at(deadline) - - @classmethod - async def _cancel_shielded_wait_asyncio_future( - cls, - future: asyncio.Future[Any], - abort_func: Callable[[], bool] | None, - ) -> None: - current_task: asyncio.Task[Any] = cls._current_asyncio_task() - abort: bool | None = None - task_cancelled: bool = False - task_cancel_msg: str | None = None - - try: - while not future.done(): - try: - await asyncio.wait({future}) - except asyncio.CancelledError as exc: - if abort is None: - if abort_func is None: - abort = False - else: - abort = bool(abort_func()) - if abort: - raise - task_cancelled = True - task_cancel_msg = cls._get_cancelled_error_message(exc) - - if task_cancelled and not future.cancelled(): - TimeoutHandle._reschedule_delayed_task_cancel(current_task, task_cancel_msg) - finally: - task_cancel_msg = None - del current_task, future, abort_func - - @staticmethod - def _get_cancelled_error_message(exc: asyncio.CancelledError) -> str | None: - msg: str | None - if exc.args: - msg = exc.args[0] - else: - msg = None - return msg + def open_cancel_scope(self, *, deadline: float = math.inf) -> CancelScope: + return CancelScope(deadline=deadline) def current_time(self) -> float: loop = asyncio.get_running_loop() @@ -174,15 +107,6 @@ async def sleep_forever(self) -> NoReturn: await loop.create_future() raise AssertionError("await an unused future cannot end in any other way than by cancellation") - def spawn_task( - self, - coro_func: Callable[..., Coroutine[Any, Any, _T]], - /, - *args: Any, - context: contextvars.Context | None = None, - ) -> SystemTask[_T]: - return SystemTask(coro_func(*args), context=context) - def create_task_group(self) -> TaskGroup: return TaskGroup() @@ -437,8 +361,7 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg future = loop.run_in_executor(None, func_call) del func_call, func, args, kwargs try: - await self._cancel_shielded_wait_asyncio_future(future, None) - assert future.done() # nosec assert_used + await TaskUtils.cancel_shielded_wait_asyncio_futures({future}) return future.result() finally: del future @@ -453,15 +376,14 @@ async def wait_future(self, future: concurrent.futures.Future[_T_co]) -> _T_co: # If future.cancel() failed, that means future.set_running_or_notify_cancel() has been called # and set future in RUNNING state. # This future cannot be cancelled anymore, therefore it must be awaited. - await self._cancel_shielded_wait_asyncio_future(future_wrapper, future.cancel) + await TaskUtils.cancel_shielded_wait_asyncio_futures({future_wrapper}, abort_func=future.cancel) + + # Unwrap "future_wrapper" to prevent reports about unhandled exceptions. if not future_wrapper.cancelled(): del future - # Unwrap "future_wrapper" instead to prevent reports about unhandled exceptions. - assert future_wrapper.done() # nosec assert_used return future_wrapper.result() finally: del future_wrapper - assert future.done() # nosec assert_used try: if future.cancelled(): diff --git a/src/easynetwork_asyncio/datagram/endpoint.py b/src/easynetwork_asyncio/datagram/endpoint.py index 94b1c6ba..d4d78ad8 100644 --- a/src/easynetwork_asyncio/datagram/endpoint.py +++ b/src/easynetwork_asyncio/datagram/endpoint.py @@ -32,6 +32,8 @@ from easynetwork.tools._utils import error_from_errno as _error_from_errno +from ..tasks import TaskUtils + if TYPE_CHECKING: import asyncio.trsock @@ -104,12 +106,19 @@ def is_closing(self) -> bool: async def recvfrom(self) -> tuple[bytes, tuple[Any, ...]]: self.__check_exceptions() if self.__transport.is_closing(): - raise _error_from_errno(_errno.ECONNABORTED) - data_and_address = await self.__recv_queue.get() - if data_and_address is None: - self.__check_exceptions() # Woken up because an error occurred ? - assert self.__transport.is_closing() # nosec assert_used - raise _error_from_errno(_errno.ECONNABORTED) # Connection lost otherwise + try: + data_and_address = self.__recv_queue.get_nowait() + except asyncio.QueueEmpty: + data_and_address = None + if data_and_address is None: + raise _error_from_errno(_errno.ECONNABORTED) + await TaskUtils.cancel_shielded_coro_yield() + else: + data_and_address = await self.__recv_queue.get() + if data_and_address is None: + self.__check_exceptions() # Woken up because an error occurred ? + assert self.__transport.is_closing() # nosec assert_used + raise _error_from_errno(_errno.ECONNABORTED) # Connection lost otherwise return data_and_address async def sendto(self, data: bytes | bytearray | memoryview, address: tuple[Any, ...] | None = None, /) -> None: @@ -122,9 +131,6 @@ async def sendto(self, data: bytes | bytearray | memoryview, address: tuple[Any, def get_extra_info(self, name: str, default: Any = None) -> Any: return self.__transport.get_extra_info(name, default) - def get_loop(self) -> asyncio.AbstractEventLoop: - return self.__protocol._get_loop() - def __check_exceptions(self) -> None: try: exc = self.__exception_queue.get_nowait() @@ -167,7 +173,7 @@ def __init__( self.__loop: asyncio.AbstractEventLoop = loop self.__recv_queue: asyncio.Queue[tuple[bytes, tuple[Any, ...]] | None] = recv_queue self.__exception_queue: asyncio.Queue[Exception] = exception_queue - self.__transport: asyncio.BaseTransport | None = None + self.__transport: asyncio.DatagramTransport | None = None self.__closed: asyncio.Future[None] = loop.create_future() self.__drain_waiters: collections.deque[asyncio.Future[None]] = collections.deque() self.__write_paused: bool = False @@ -183,20 +189,11 @@ def __del__(self) -> None: # pragma: no cover if closed.done() and not closed.cancelled(): closed.exception() - def connection_made(self, transport: asyncio.BaseTransport) -> None: + def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore[override] assert self.__transport is None, "Transport already set" # nosec assert_used self.__transport = transport self.__connection_lost = False - - peername: tuple[Any, ...] | None = transport.get_extra_info("peername", None) - if peername is not None and isinstance(self.__loop, asyncio.base_events.BaseEventLoop): - # There is an asyncio issue where the private address attribute is not updated with the actual remote address - # if the transport is instanciated with an external socket: - # await loop.create_datagram_endpoint(sock=my_socket) - # - # This is a monkeypatch to force update the internal address attribute - if hasattr(transport, "_address") and getattr(transport, "_address") != peername: - setattr(transport, "_address", peername) + _monkeypatch_transport(transport, self.__loop) def connection_lost(self, exc: Exception | None) -> None: self.__connection_lost = True @@ -269,3 +266,15 @@ def _get_loop(self) -> asyncio.AbstractEventLoop: def _writing_paused(self) -> bool: return self.__write_paused + + +def _monkeypatch_transport(transport: asyncio.DatagramTransport, loop: asyncio.AbstractEventLoop) -> None: + if isinstance(loop, asyncio.base_events.BaseEventLoop) and hasattr(transport, "_address"): + # There is an asyncio issue where the private address attribute is not updated with the actual remote address + # if the transport is instanciated with an external socket: + # await loop.create_datagram_endpoint(sock=my_socket) + # + # This is a monkeypatch to force update the internal address attribute + peername: tuple[Any, ...] | None = transport.get_extra_info("peername", None) + if peername is not None and getattr(transport, "_address") != peername: + setattr(transport, "_address", peername) diff --git a/src/easynetwork_asyncio/datagram/socket.py b/src/easynetwork_asyncio/datagram/socket.py index bb6318f1..f58088fe 100644 --- a/src/easynetwork_asyncio/datagram/socket.py +++ b/src/easynetwork_asyncio/datagram/socket.py @@ -20,12 +20,10 @@ __all__ = ["AsyncioTransportDatagramSocketAdapter", "RawDatagramSocketAdapter"] import asyncio -import errno import socket as _socket from typing import TYPE_CHECKING, Any, final from easynetwork.api_async.backend.abc import AsyncDatagramSocketAdapter as AbstractAsyncDatagramSocketAdapter -from easynetwork.tools._utils import error_from_errno as _error_from_errno from ..socket import AsyncSocket @@ -62,15 +60,6 @@ async def aclose(self) -> None: def is_closing(self) -> bool: return self.__endpoint.is_closing() - def get_local_address(self) -> tuple[Any, ...]: - local_address: tuple[Any, ...] | None = self.__endpoint.get_extra_info("sockname") - if local_address is None: - raise _error_from_errno(errno.ENOTSOCK) - return local_address - - def get_remote_address(self) -> tuple[Any, ...] | None: - return self.__endpoint.get_extra_info("peername") - async def recvfrom(self, bufsize: int, /) -> tuple[bytes, tuple[Any, ...]]: data, address = await self.__endpoint.recvfrom() if len(data) > bufsize: @@ -106,15 +95,6 @@ async def aclose(self) -> None: def is_closing(self) -> bool: return self.__socket.is_closing() - def get_local_address(self) -> tuple[Any, ...]: - return self.__socket.socket.getsockname() - - def get_remote_address(self) -> tuple[Any, ...] | None: - try: - return self.__socket.socket.getpeername() - except OSError: - return None - async def recvfrom(self, bufsize: int, /) -> tuple[bytes, tuple[Any, ...]]: return await self.__socket.recvfrom(bufsize) diff --git a/src/easynetwork_asyncio/runner.py b/src/easynetwork_asyncio/runner.py deleted file mode 100644 index b3077351..00000000 --- a/src/easynetwork_asyncio/runner.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -"""asyncio engine for easynetwork.api_async -""" - -from __future__ import annotations - -__all__ = ["AsyncioRunner"] - -import asyncio -import contextlib -from collections.abc import Callable, Coroutine -from typing import Any, Self, TypeVar - -from easynetwork.api_async.backend.abc import Runner as AbstractRunner - -_T = TypeVar("_T") - - -class AsyncioRunner(AbstractRunner): - __slots__ = ("__runner",) - - def __init__(self, runner: asyncio.Runner) -> None: - super().__init__() - - self.__runner: asyncio.Runner = runner - - def __enter__(self) -> Self: - self.__runner.__enter__() - return super().__enter__() - - def close(self) -> None: - return self.__runner.close() - - def run(self, coro_func: Callable[..., Coroutine[Any, Any, _T]], *args: Any) -> _T: - with contextlib.closing(coro_func(*args)) as coro: # Avoid ResourceWarning by always closing the coroutine - return self.__runner.run(coro) diff --git a/src/easynetwork_asyncio/socket.py b/src/easynetwork_asyncio/socket.py index 8a744ea8..3f86db05 100644 --- a/src/easynetwork_asyncio/socket.py +++ b/src/easynetwork_asyncio/socket.py @@ -30,6 +30,8 @@ from easynetwork.tools._utils import check_socket_no_ssl as _check_socket_no_ssl, error_from_errno as _error_from_errno +from .tasks import TaskUtils + if TYPE_CHECKING: from types import TracebackType @@ -152,9 +154,7 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int | Non if task_id in self.__waiters: raise _error_from_errno(_errno.EBUSY) - task = asyncio.current_task(self.__loop) - if task is None: # pragma: no cover - raise RuntimeError("This function will not be executed with the bound event loop.") + task = TaskUtils.current_asyncio_task(self.__loop) with contextlib.ExitStack() as stack: self.__tasks.add(task) @@ -179,7 +179,7 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int | Non except asyncio.CancelledError: if self.__socket is not None: raise - if (not task.cancelling() > task_cancelling) or task.uncancel() > task_cancelling: + if task.cancelling() <= task_cancelling or task.uncancel() > task_cancelling: raise raise _error_from_errno(abort_errno) from None finally: diff --git a/src/easynetwork_asyncio/stream/listener.py b/src/easynetwork_asyncio/stream/listener.py index 6dcd0334..ca8cef21 100644 --- a/src/easynetwork_asyncio/stream/listener.py +++ b/src/easynetwork_asyncio/stream/listener.py @@ -23,7 +23,7 @@ import asyncio.streams from abc import abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, final from easynetwork.api_async.backend.abc import ( AcceptedSocket as AbstractAcceptedSocket, @@ -65,9 +65,6 @@ async def accept(self) -> AbstractAcceptedSocket: client_socket, _ = await self.__socket.accept() return self.__accepted_socket_factory(client_socket, self.__socket.loop) - def get_local_address(self) -> tuple[Any, ...]: - return self.__socket.socket.getsockname() - def socket(self) -> asyncio.trsock.TransportSocket: return self.__socket.socket diff --git a/src/easynetwork_asyncio/stream/socket.py b/src/easynetwork_asyncio/stream/socket.py index fc4c8031..ddde25c3 100644 --- a/src/easynetwork_asyncio/stream/socket.py +++ b/src/easynetwork_asyncio/stream/socket.py @@ -20,16 +20,11 @@ __all__ = ["AsyncioTransportStreamSocketAdapter", "RawStreamSocketAdapter"] import asyncio -import errno import socket as _socket from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Self, cast, final +from typing import TYPE_CHECKING, final -from easynetwork.api_async.backend.abc import ( - AsyncHalfCloseableStreamSocketAdapter as AbstractAsyncHalfCloseableStreamSocketAdapter, - AsyncStreamSocketAdapter as AbstractAsyncStreamSocketAdapter, -) -from easynetwork.tools._utils import error_from_errno as _error_from_errno +from easynetwork.api_async.backend.abc import AsyncStreamSocketAdapter as AbstractAsyncStreamSocketAdapter from ..socket import AsyncSocket @@ -37,6 +32,7 @@ import asyncio.trsock +@final class AsyncioTransportStreamSocketAdapter(AbstractAsyncStreamSocketAdapter): __slots__ = ( "__reader", @@ -44,15 +40,6 @@ class AsyncioTransportStreamSocketAdapter(AbstractAsyncStreamSocketAdapter): "__socket", ) - def __new__( - cls, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - ) -> Self: - if cls is AsyncioTransportStreamSocketAdapter and writer.can_write_eof(): - return cast(Self, super().__new__(AsyncioTransportHalfCloseableStreamSocketAdapter)) - return super().__new__(cls) - def __init__( self, reader: asyncio.StreamReader, @@ -78,18 +65,6 @@ async def aclose(self) -> None: def is_closing(self) -> bool: return self.__writer.is_closing() - def get_local_address(self) -> tuple[Any, ...]: - local_address: tuple[Any, ...] | None = self.__writer.get_extra_info("sockname") - if local_address is None: - raise _error_from_errno(errno.ENOTSOCK) - return local_address - - def get_remote_address(self) -> tuple[Any, ...]: - remote_address: tuple[Any, ...] | None = self.__writer.get_extra_info("peername") - if remote_address is None: - raise _error_from_errno(errno.ENOTCONN) - return remote_address - async def recv(self, bufsize: int, /) -> bytes: if bufsize < 0: raise ValueError("'bufsize' must be a positive or null integer") @@ -103,8 +78,7 @@ async def sendall_fromiter(self, iterable_of_data: Iterable[bytes], /) -> None: self.__writer.writelines(iterable_of_data) await self.__writer.drain() - async def _send_eof_impl(self) -> None: - assert self.__writer.can_write_eof() # nosec assert_used + async def send_eof(self) -> None: self.__writer.write_eof() await asyncio.sleep(0) @@ -113,22 +87,7 @@ def socket(self) -> asyncio.trsock.TransportSocket: @final -class AsyncioTransportHalfCloseableStreamSocketAdapter( - AsyncioTransportStreamSocketAdapter, - AbstractAsyncHalfCloseableStreamSocketAdapter, -): - __slots__ = () - - def __new__(cls, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> Self: - if not writer.can_write_eof(): - raise ValueError(f"{writer!r} cannot write eof") - return super().__new__(cls, reader, writer) - - send_eof = AsyncioTransportStreamSocketAdapter._send_eof_impl - - -@final -class RawStreamSocketAdapter(AbstractAsyncHalfCloseableStreamSocketAdapter): +class RawStreamSocketAdapter(AbstractAsyncStreamSocketAdapter): __slots__ = ("__socket",) def __init__( @@ -149,12 +108,6 @@ async def aclose(self) -> None: def is_closing(self) -> bool: return self.__socket.is_closing() - def get_local_address(self) -> tuple[Any, ...]: - return self.__socket.socket.getsockname() - - def get_remote_address(self) -> tuple[Any, ...]: - return self.__socket.socket.getpeername() - async def recv(self, bufsize: int, /) -> bytes: return await self.__socket.recv(bufsize) diff --git a/src/easynetwork_asyncio/tasks.py b/src/easynetwork_asyncio/tasks.py index b25c557f..b1c97f5b 100644 --- a/src/easynetwork_asyncio/tasks.py +++ b/src/easynetwork_asyncio/tasks.py @@ -17,21 +17,21 @@ from __future__ import annotations -__all__ = ["Task", "TaskGroup", "TimeoutHandle", "timeout", "timeout_at"] +__all__ = ["CancelScope", "Task", "TaskGroup", "TaskUtils"] import asyncio import contextvars +import enum import math from collections import deque -from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, final +from collections.abc import Callable, Coroutine, Iterable +from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, Self, TypeVar, final from weakref import WeakKeyDictionary from easynetwork.api_async.backend.abc import ( - SystemTask as AbstractSystemTask, + CancelScope as AbstractCancelScope, Task as AbstractTask, TaskGroup as AbstractTaskGroup, - TimeoutHandle as AbstractTimeoutHandle, ) if TYPE_CHECKING: @@ -43,6 +43,7 @@ _T_co = TypeVar("_T_co", covariant=True) +@final class Task(AbstractTask[_T_co]): __slots__ = ("__t", "__h") @@ -50,6 +51,9 @@ def __init__(self, task: asyncio.Task[_T_co]) -> None: self.__t: asyncio.Task[_T_co] = task self.__h: int | None = None + def __repr__(self) -> str: + return f"" + def __hash__(self) -> int: if (h := self.__h) is None: self.__h = h = hash((self.__class__, self.__t, 0xFF)) @@ -71,43 +75,17 @@ def cancelled(self) -> bool: async def wait(self) -> None: task = self.__t - try: - if task.done(): - return - await asyncio.wait({task}) - finally: - del task, self # This is needed to avoid circular reference with raised exception + await asyncio.wait({task}) async def join(self) -> _T_co: - # If the caller cancels the join() task, it should not stop the inner task - # e.g. when awaiting from an another task than the one which creates the TaskGroup, - # you want to stop joining the sub-task, not accidentally cancel it. - # It is primarily to avoid error prone code where tasks were not explicitly cancelled using task.cancel() task = self.__t try: return await asyncio.shield(task) finally: del task, self # This is needed to avoid circular reference with raised exception - @property - def _asyncio_task(self) -> asyncio.Task[_T_co]: - return self.__t - - -@final -class SystemTask(Task[_T_co], AbstractSystemTask[_T_co]): - __slots__ = () - - def __init__( - self, - coroutine: Coroutine[Any, Any, _T_co], - *, - context: contextvars.Context | None = None, - ) -> None: - super().__init__(asyncio.create_task(coroutine, context=context)) - async def join_or_cancel(self) -> _T_co: - task = self._asyncio_task + task = self.__t try: return await task finally: @@ -149,89 +127,177 @@ def start_soon( return Task(self.__asyncio_tg.create_task(coro_func(*args), context=context)) -@final -class TimeoutHandle(AbstractTimeoutHandle): - __slots__ = ("__handle", "__only_move_on", "__already_delayed_cancellation") +class _ScopeState(enum.Enum): + CREATED = "created" + ENTERED = "entered" + EXITED = "exited" - __current_handle_dict: WeakKeyDictionary[asyncio.Task[Any], deque[TimeoutHandle]] = WeakKeyDictionary() - __delayed_task_cancel_dict: WeakKeyDictionary[asyncio.Task[Any], asyncio.Handle] = WeakKeyDictionary() - def __init__(self, handle: asyncio.Timeout, *, only_move_on: bool = False) -> None: - super().__init__() - self.__handle: asyncio.Timeout = handle - self.__only_move_on: bool = bool(only_move_on) - self.__already_delayed_cancellation: bool = True +class _DelayedCancel(NamedTuple): + handle: asyncio.Handle + message: str | None - async def __aenter__(self) -> Self: - timeout_handle: asyncio.Timeout = self.__handle - await type(timeout_handle).__aenter__(timeout_handle) - current_task = asyncio.current_task() - assert current_task is not None # nosec assert_used - current_handle_dict = self.__current_handle_dict - if current_task not in current_handle_dict: - current_handle_dict[current_task] = deque() - current_task.add_done_callback(current_handle_dict.pop) - current_handle_dict[current_task].appendleft(self) - self.__already_delayed_cancellation = current_task in self.__delayed_task_cancel_dict + +@final +class CancelScope(AbstractCancelScope): + __slots__ = ( + "__host_task", + "__state", + "__cancel_called", + "__cancelled_caught", + "__task_cancelling", + "__deadline", + "__timeout_handle", + "__delayed_cancellation_on_enter", + ) + + __current_task_scope_dict: WeakKeyDictionary[asyncio.Task[Any], deque[CancelScope]] = WeakKeyDictionary() + __delayed_task_cancel_dict: WeakKeyDictionary[asyncio.Task[Any], _DelayedCancel] = WeakKeyDictionary() + + def __init__(self, *, deadline: float = math.inf) -> None: + super().__init__() + self.__host_task: asyncio.Task[Any] | None = None + self.__state: _ScopeState = _ScopeState.CREATED + self.__cancel_called: bool = False + self.__cancelled_caught: bool = False + self.__task_cancelling: int = 0 + self.__deadline: float = math.inf + self.__timeout_handle: asyncio.TimerHandle | None = None + self.reschedule(deadline) + + def __repr__(self) -> str: + active = self.__state is _ScopeState.ENTERED + cancel_called = self.__cancel_called + cancelled_caught = self.__cancelled_caught + host_task = self.__host_task + deadline = self.__deadline + + info = f"{active=!r}, {cancelled_caught=!r}, {cancel_called=!r}, {host_task=!r}, {deadline=!r}" + return f"<{self.__class__.__name__}({info})>" + + def __enter__(self) -> Self: + if self.__state is not _ScopeState.CREATED: + raise RuntimeError("CancelScope entered twice") + + self.__host_task = current_task = TaskUtils.current_asyncio_task() + self.__task_cancelling = current_task.cancelling() + + current_task_scope = self.__current_task_scope_dict + if current_task not in current_task_scope: + current_task_scope[current_task] = deque() + current_task.add_done_callback(current_task_scope.pop) + current_task_scope[current_task].appendleft(self) + + self.__state = _ScopeState.ENTERED + + if self.__cancel_called: + self.__deliver_cancellation() + else: + self.__timeout() return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - timeout_handle: asyncio.Timeout = self.__handle - current_task = asyncio.current_task() - assert current_task is not None # nosec assert_used + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool: + if self.__state is not _ScopeState.ENTERED: + raise RuntimeError("This cancel scope is not active") + + if TaskUtils.current_asyncio_task() is not self.__host_task: + raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") + + if self._current_task_scope(self.__host_task) is not self: + raise RuntimeError("Attempted to exit a cancel scope that isn't the current tasks's current cancel scope") + + self.__state = _ScopeState.EXITED + + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + + host_task, self.__host_task = self.__host_task, None + self.__current_task_scope_dict[host_task].popleft() + + if self.__cancel_called: + task_cancelling = host_task.uncancel() + if isinstance(exc_val, asyncio.CancelledError): + self.__cancelled_caught = task_cancelling <= self.__task_cancelling or self.__cancellation_id() in exc_val.args + + delayed_task_cancel = self.__delayed_task_cancel_dict.get(host_task, None) + if delayed_task_cancel is not None and delayed_task_cancel.message == self.__cancellation_id(): + del self.__delayed_task_cancel_dict[host_task] + delayed_task_cancel.handle.cancel() + + current_task_scope = self._current_task_scope(host_task) + if current_task_scope is None: + if task_cancelling > 0: + self._reschedule_delayed_task_cancel(host_task, None) + else: + if current_task_scope.__cancel_called: + self._reschedule_delayed_task_cancel(host_task, current_task_scope.__cancellation_id()) + + return self.__cancelled_caught + + def __deliver_cancellation(self) -> None: + if self.__host_task is None: + # Scope not active. + return try: - await type(timeout_handle).__aexit__(timeout_handle, exc_type, exc_val, exc_tb) - except TimeoutError: - if self.__only_move_on: - return True - raise - else: - return None - finally: - delayed_task_cancel = None - try: - self.__current_handle_dict[current_task].popleft() - except LookupError: # pragma: no cover - pass - finally: - if not self.__already_delayed_cancellation: - if not self.__current_handle_dict.get(current_task): - delayed_task_cancel = self.__delayed_task_cancel_dict.pop(current_task, None) - if delayed_task_cancel is not None: - delayed_task_cancel.cancel() - del current_task, exc_val, exc_tb, self + self.__delayed_task_cancel_dict.pop(self.__host_task).handle.cancel() + except KeyError: + pass + self.__host_task.cancel(msg=self.__cancellation_id()) - def when(self) -> float: - deadline: float | None = self.__handle.when() - return deadline if deadline is not None else math.inf + def __cancellation_id(self) -> str: + return f"Cancelled by cancel scope {id(self):x}" - def reschedule(self, when: float) -> None: - return self.__handle.reschedule(self._cast_time(when)) + def cancel(self) -> None: + if not self.__cancel_called: + self.__cancel_called = True + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + self.__deliver_cancellation() - def expired(self) -> bool: - return self.__handle.expired() + def cancel_called(self) -> bool: + return self.__cancel_called - @staticmethod - def _cast_time(time_value: float) -> float | None: - assert time_value is not None # nosec assert_used - return time_value if time_value != math.inf else None + def cancelled_caught(self) -> bool: + return self.__cancelled_caught + + def when(self) -> float: + return self.__deadline + + def reschedule(self, when: float, /) -> None: + if math.isnan(when): + raise ValueError("deadline is NaN") + self.__deadline = max(when, 0) + if self.__timeout_handle: + self.__timeout_handle.cancel() + self.__timeout_handle = None + if self.__state is _ScopeState.ENTERED and not self.__cancel_called: + self.__timeout() + + def __timeout(self) -> None: + if self.__deadline != math.inf: + loop = asyncio.get_running_loop() + if loop.time() >= self.__deadline: + self.cancel() + else: + self.__timeout_handle = loop.call_at(self.__deadline, self.__timeout) + + @classmethod + def _current_task_scope(cls, task: asyncio.Task[Any]) -> CancelScope | None: + try: + return cls.__current_task_scope_dict[task][0] + except LookupError: + return None @classmethod - def _reschedule_delayed_task_cancel(cls, task: asyncio.Task[Any], cancel_msg: str | None) -> None: + def _reschedule_delayed_task_cancel(cls, task: asyncio.Task[Any], cancel_msg: str | None) -> asyncio.Handle: + if task in cls.__delayed_task_cancel_dict: + raise RuntimeError("CancelScope issue.") # pragma: no cover task_cancel_handle = task.get_loop().call_soon(cls.__cancel_task_unless_done, task, cancel_msg) - if cls.__current_handle_dict.get(task): - if task in cls.__delayed_task_cancel_dict: - cls.__delayed_task_cancel_dict[task].cancel() - cls.__delayed_task_cancel_dict[task] = task_cancel_handle - else: - assert task not in cls.__delayed_task_cancel_dict # nosec assert_used - cls.__delayed_task_cancel_dict[task] = task_cancel_handle - task.get_loop().call_soon(cls.__delayed_task_cancel_dict.pop, task, None) + cls.__delayed_task_cancel_dict[task] = _DelayedCancel(task_cancel_handle, cancel_msg) + task.get_loop().call_soon(cls.__delayed_task_cancel_dict.pop, task, None) + return task_cancel_handle @staticmethod def __cancel_task_unless_done(task: asyncio.Task[Any], cancel_msg: str | None) -> None: @@ -241,17 +307,84 @@ def __cancel_task_unless_done(task: asyncio.Task[Any], cancel_msg: str | None) - task.cancel(cancel_msg) -def timeout(delay: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout(TimeoutHandle._cast_time(delay))) +@final +class TaskUtils: + @staticmethod + def current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Task[Any]: + t: asyncio.Task[Any] | None = asyncio.current_task(loop=loop) + if t is None: + raise RuntimeError("This function should be called within a task.") + return t + + @classmethod + async def cancel_shielded_wait_asyncio_futures( + cls, + fs: Iterable[asyncio.Future[Any]], + *, + abort_func: Callable[[], bool] | None = None, + ) -> asyncio.Handle | None: + fs = set(fs) + current_task: asyncio.Task[Any] = cls.current_asyncio_task() + abort: bool | None = None + task_cancelled: bool = False + task_cancel_msg: str | None = None + + try: + _schedule_task_discard(fs) + while fs: + try: + await asyncio.wait(fs) + except asyncio.CancelledError as exc: + if abort is None: + if abort_func is None: + abort = False + else: + abort = bool(abort_func()) + if abort: + raise + task_cancelled = True + task_cancel_msg = _get_cancelled_error_message(exc) + + if task_cancelled: + return CancelScope._reschedule_delayed_task_cancel(current_task, task_cancel_msg) + return None + finally: + del current_task, fs, abort_func + task_cancel_msg = None + + @classmethod + async def cancel_shielded_coro_yield(cls) -> None: + current_task: asyncio.Task[Any] = cls.current_asyncio_task() + try: + await asyncio.sleep(0) + except asyncio.CancelledError as exc: + CancelScope._reschedule_delayed_task_cancel(current_task, _get_cancelled_error_message(exc)) + finally: + del current_task + @classmethod + async def cancel_shielded_await_task(cls, task: asyncio.Task[_T_co]) -> _T_co: + # This task must be unregistered in order not to be cancelled by runner at event loop shutdown + asyncio._unregister_task(task) -def timeout_at(deadline: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout_at(TimeoutHandle._cast_time(deadline))) + try: + current_task_cancel_handle = await cls.cancel_shielded_wait_asyncio_futures({task}) + if current_task_cancel_handle is not None and task.cancelled(): + current_task_cancel_handle.cancel() + return task.result() + finally: + del task -def move_on_after(delay: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout(TimeoutHandle._cast_time(delay)), only_move_on=True) +def _get_cancelled_error_message(exc: asyncio.CancelledError) -> str | None: + msg: str | None + if exc.args: + msg = exc.args[0] + else: + msg = None + return msg -def move_on_at(deadline: float) -> TimeoutHandle: - return TimeoutHandle(asyncio.timeout_at(TimeoutHandle._cast_time(deadline)), only_move_on=True) +def _schedule_task_discard(fs: set[asyncio.Future[Any]]) -> None: + for f in fs: + f.add_done_callback(fs.discard) diff --git a/src/easynetwork_asyncio/threads.py b/src/easynetwork_asyncio/threads.py index 05a16b11..a43dc2c3 100644 --- a/src/easynetwork_asyncio/threads.py +++ b/src/easynetwork_asyncio/threads.py @@ -22,12 +22,19 @@ import asyncio import concurrent.futures import contextvars -from collections.abc import Callable, Coroutine -from typing import Any, ParamSpec, TypeVar, final +import inspect +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, ParamSpec, Self, TypeVar, final from easynetwork.api_async.backend.abc import ThreadsPortal as AbstractThreadsPortal from easynetwork.api_async.backend.sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar -from easynetwork.tools._utils import transform_future_exception as _transform_future_exception +from easynetwork.tools._lock import ForkSafeLock +from easynetwork.tools._utils import exception_with_notes as _exception_with_notes + +from .tasks import TaskUtils + +if TYPE_CHECKING: + from types import TracebackType _P = ParamSpec("_P") _T = TypeVar("_T") @@ -35,78 +42,132 @@ @final class ThreadsPortal(AbstractThreadsPortal): - __slots__ = ("__loop",) + __slots__ = ("__loop", "__lock", "__task_group", "__call_soon_waiters") - def __init__(self, *, loop: asyncio.AbstractEventLoop | None = None) -> None: + def __init__(self) -> None: super().__init__() - if loop is None: - loop = asyncio.get_running_loop() - self.__loop: asyncio.AbstractEventLoop = loop + self.__loop: asyncio.AbstractEventLoop | None = None + self.__lock = ForkSafeLock() + self.__task_group: asyncio.TaskGroup = asyncio.TaskGroup() + self.__call_soon_waiters: set[asyncio.Future[None]] = set() + + async def __aenter__(self) -> Self: + if self.__loop is not None: + raise RuntimeError("ThreadsPortal entered twice.") + await self.__task_group.__aenter__() + self.__loop = asyncio.get_running_loop() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + try: + with self.__lock.get(): + self.__loop = None - def run_coroutine(self, coro_func: Callable[_P, Coroutine[Any, Any, _T]], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: - self.__check_running_loop() - return self.__get_result(self.__run_coroutine_soon(coro_func, *args, **kwargs)) + while self.__call_soon_waiters: + await TaskUtils.cancel_shielded_wait_asyncio_futures(self.__call_soon_waiters) + await self.__task_group.__aexit__(exc_type, exc_val, exc_tb) + finally: + del self, exc_val, exc_tb - def __run_coroutine_soon( + def run_coroutine_soon( self, - coro_func: Callable[_P, Coroutine[Any, Any, _T]], + coro_func: Callable[_P, Awaitable[_T]], /, *args: _P.args, **kwargs: _P.kwargs, ) -> concurrent.futures.Future[_T]: - coroutine = coro_func(*args, **kwargs) - if _sniffio_current_async_library_cvar is not None: - ctx = contextvars.copy_context() - ctx.run(_sniffio_current_async_library_cvar.set, "asyncio") - return ctx.run(asyncio.run_coroutine_threadsafe, coroutine, self.__loop) # type: ignore[arg-type] - - return asyncio.run_coroutine_threadsafe(coroutine, self.__loop) - - def run_sync(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: - self.__check_running_loop() - return self.__get_result(self.__run_sync_soon(func, *args, **kwargs)) - - def __run_sync_soon(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> concurrent.futures.Future[_T]: + def schedule_task() -> concurrent.futures.Future[_T]: + future: concurrent.futures.Future[_T] = concurrent.futures.Future() + + async def coroutine() -> None: + try: + result = await coro_func(*args, **kwargs) + except asyncio.CancelledError: + future.cancel() + future.set_running_or_notify_cancel() + raise + except BaseException as exc: + if future.set_running_or_notify_cancel(): + future.set_exception(exc) + if not isinstance(exc, Exception): + raise # pragma: no cover + else: + if future.set_running_or_notify_cancel(): + future.set_result(result) + + task = self.__task_group.create_task(coroutine()) + loop = task.get_loop() + with self.__lock.get(): + loop.call_soon(self.__register_waiter(self.__call_soon_waiters, loop).set_result, None) + + def on_fut_done(future: concurrent.futures.Future[_T]) -> None: + if future.cancelled(): + try: + self.run_sync(task.cancel) + except RuntimeError: + # on_fut_done() called from coroutine() + # or the portal is already shut down + pass + + future.add_done_callback(on_fut_done) + + return future + + return self.run_sync(schedule_task) + + def run_sync_soon(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> concurrent.futures.Future[_T]: def callback() -> None: + waiter.set_result(None) + if not future.set_running_or_notify_cancel(): + return try: result = func(*args, **kwargs) + if inspect.iscoroutine(result): + result.close() # Prevent ResourceWarnings + msg = "func is a coroutine function." + note = "You should use run_coroutine() or run_coroutine_soon() instead." + raise _exception_with_notes(TypeError(msg), note) except BaseException as exc: - future.set_exception(_transform_future_exception(exc)) + future.set_exception(exc) if isinstance(exc, (SystemExit, KeyboardInterrupt)): # pragma: no cover raise else: future.set_result(result) + with self.__lock.get(): + loop = self.__check_loop() + waiter = self.__register_waiter(self.__call_soon_waiters, loop) + ctx = contextvars.copy_context() if _sniffio_current_async_library_cvar is not None: ctx.run(_sniffio_current_async_library_cvar.set, "asyncio") future: concurrent.futures.Future[_T] = concurrent.futures.Future() - future.set_running_or_notify_cancel() - self.__loop.call_soon_threadsafe(callback, context=ctx) + loop.call_soon_threadsafe(callback, context=ctx) return future - @staticmethod - def __get_result(future: concurrent.futures.Future[_T]) -> _T: - try: - return future.result() - except concurrent.futures.CancelledError: - if not future.cancelled(): # raised from future.exception() - raise - raise asyncio.CancelledError() from None - finally: - del future - - def __check_running_loop(self) -> None: + def __check_loop(self) -> asyncio.AbstractEventLoop: + loop = self.__loop + if loop is None: + raise RuntimeError("ThreadsPortal not running.") try: running_loop = asyncio.get_running_loop() except RuntimeError: - return - if running_loop is self.__loop: - raise RuntimeError("must be called in a different OS thread") + return loop + if running_loop is loop: + raise RuntimeError("This function must be called in a different OS thread") + return loop - @property - def loop(self) -> asyncio.AbstractEventLoop: - return self.__loop + @staticmethod + def __register_waiter(waiters: set[asyncio.Future[None]], loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]: + waiter: asyncio.Future[None] = loop.create_future() + waiters.add(waiter) + waiter.add_done_callback(waiters.discard) + return waiter diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index 2246aa02..fe25826e 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -4,14 +4,20 @@ import contextvars import time from collections.abc import Awaitable, Callable -from concurrent.futures import CancelledError as FutureCancelledError, Future -from typing import Any +from concurrent.futures import CancelledError as FutureCancelledError, Future, wait as wait_concurrent_futures +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Literal from easynetwork.api_async.backend.factory import AsyncBackendFactory -from easynetwork_asyncio.backend import AsyncioBackend +from easynetwork_asyncio.backend import AsyncIOBackend import pytest +if TYPE_CHECKING: + from unittest.mock import AsyncMock + + from pytest_mock import MockerFixture + cvar_for_test: contextvars.ContextVar[str] = contextvars.ContextVar("cvar_for_test", default="") @@ -19,18 +25,18 @@ class TestAsyncioBackend: @pytest.fixture @staticmethod - def backend() -> AsyncioBackend: + def backend() -> AsyncIOBackend: backend = AsyncBackendFactory.new("asyncio") - assert isinstance(backend, AsyncioBackend) + assert isinstance(backend, AsyncIOBackend) return backend - async def test____use_asyncio_transport____True_by_default(self, backend: AsyncioBackend) -> None: + async def test____use_asyncio_transport____True_by_default(self, backend: AsyncIOBackend) -> None: assert backend.using_asyncio_transport() async def test____cancel_shielded_coro_yield____mute_cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: task: asyncio.Task[None] = event_loop.create_task(backend.cancel_shielded_coro_yield()) @@ -48,7 +54,7 @@ async def test____cancel_shielded_coro_yield____cancel_at_the_next_checkpoint( self, cancel_message: str | None, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: test_list: list[str] = [] @@ -81,7 +87,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____always_continue_on_cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: task: asyncio.Task[int] = event_loop.create_task(backend.ignore_cancellation(asyncio.sleep(0.5, 42))) @@ -96,7 +102,7 @@ async def test____ignore_cancellation____always_continue_on_cancellation( async def test____ignore_cancellation____task_does_not_appear_in_registered_tasks( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine() -> bool: task = asyncio.current_task() @@ -110,7 +116,7 @@ async def coroutine() -> bool: async def test____ignore_cancellation____coroutine_cancelled_itself( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def self_cancellation() -> None: task = asyncio.current_task() @@ -125,28 +131,43 @@ async def self_cancellation() -> None: await task assert task.cancelled() + @pytest.mark.xfail("sys.version_info < (3, 12)", reason="asyncio.Task.get_context() does not exist before Python 3.12") + async def test____ignore_cancellation____share_same_context_with_host_task( + self, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + await asyncio.sleep(0.1) + + cvar_for_test.set("after_in_coroutine") + + cvar_for_test.set("before_in_current_task") + await backend.ignore_cancellation(coroutine()) + + assert cvar_for_test.get() == "after_in_coroutine" + async def test____timeout____respected( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.timeout(1): + with backend.timeout(1): assert await asyncio.sleep(0.5, 42) == 42 async def test____timeout____timeout_error( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: with pytest.raises(TimeoutError): - async with backend.timeout(0.25): + with backend.timeout(0.25): await asyncio.sleep(0.5, 42) async def test____timeout____cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.timeout(0.25): + with backend.timeout(0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -158,27 +179,27 @@ async def coroutine() -> None: async def test____timeout_at____respected( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.timeout_at(event_loop.time() + 1): + with backend.timeout_at(event_loop.time() + 1): assert await asyncio.sleep(0.5, 42) == 42 async def test____timeout_at____timeout_error( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: with pytest.raises(TimeoutError): - async with backend.timeout_at(event_loop.time() + 0.25): + with backend.timeout_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) async def test____timeout_at____cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.timeout_at(event_loop.time() + 0.25): + with backend.timeout_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -189,29 +210,29 @@ async def coroutine() -> None: async def test____move_on_after____respected( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.move_on_after(1) as handle: + with backend.move_on_after(1) as scope: assert await asyncio.sleep(0.5, 42) == 42 - assert not handle.expired() + assert not scope.cancelled_caught() async def test____move_on_after____timeout_error( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.move_on_after(0.25) as handle: + with backend.move_on_after(0.25) as scope: await asyncio.sleep(0.5, 42) - assert handle.expired() + assert scope.cancelled_caught() async def test____move_on_after____cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.move_on_after(0.25): + with backend.move_on_after(0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -223,30 +244,30 @@ async def coroutine() -> None: async def test____move_on_at____respected( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.move_on_at(event_loop.time() + 1) as handle: + with backend.move_on_at(event_loop.time() + 1) as scope: assert await asyncio.sleep(0.5, 42) == 42 - assert not handle.expired() + assert not scope.cancelled_caught() async def test____move_on_at____timeout_error( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async with backend.move_on_at(event_loop.time() + 0.25) as handle: + with backend.move_on_at(event_loop.time() + 0.25) as scope: await asyncio.sleep(0.5, 42) - assert handle.expired() + assert scope.cancelled_caught() async def test____move_on_at____cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine() -> None: - async with backend.move_on_at(event_loop.time() + 0.25): + with backend.move_on_at(event_loop.time() + 0.25): await asyncio.sleep(0.5, 42) task = event_loop.create_task(coroutine()) @@ -258,7 +279,7 @@ async def coroutine() -> None: async def test____sleep_forever____sleep_until_cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: sleep_task = event_loop.create_task(backend.sleep_forever()) @@ -267,65 +288,134 @@ async def test____sleep_forever____sleep_until_cancellation( with pytest.raises(asyncio.CancelledError): await sleep_task - async def test____spawn_task____run_coroutine_in_background( + async def test____open_cancel_scope____unbound_cancel_scope____cancel_when_entering( self, - backend: AsyncioBackend, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - return await asyncio.sleep(0.5, value) + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None - task = backend.spawn_task(coroutine, 42) - await asyncio.sleep(1) - assert task.done() - assert await task.join() == 42 + scope = backend.open_cancel_scope() + scope.cancel() + assert scope.cancel_called() - async def test____spawn_task____task_cancellation( + assert current_task.cancelling() == 0 + await asyncio.sleep(0.1) + + with scope: + assert current_task.cancelling() == 1 + scope.cancel() + assert current_task.cancelling() == 1 + await backend.coro_yield() + + assert current_task.cancelling() == 0 + assert scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + + async def test____open_cancel_scope____unbound_cancel_scope____deadline_scheduled_when_entering( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - return await asyncio.sleep(0.5, value) + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None - task = backend.spawn_task(coroutine, 42) + scope = backend.open_cancel_scope() + scope.reschedule(event_loop.time() + 1) - event_loop.call_later(0.2, task.cancel) + assert current_task.cancelling() == 0 + await asyncio.sleep(0.5) - with pytest.raises(asyncio.CancelledError): - await task.join() + with backend.timeout(0.6): + with scope: + assert current_task.cancelling() == 0 + await backend.sleep_forever() - assert task.cancelled() + assert scope.cancelled_caught() + + await event_loop.create_task(coroutine()) - async def test____spawn_task____exception( + async def test____open_cancel_scope____overwrite_defined_deadline( self, - backend: AsyncioBackend, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, ) -> None: - async def coroutine(value: int) -> int: - await asyncio.sleep(0.1) - raise ZeroDivisionError + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None - task = backend.spawn_task(coroutine, 42) + with backend.move_on_after(1) as scope: + await backend.sleep(0.5) + scope.deadline += 1 + await backend.sleep(1) + del scope.deadline + assert scope.deadline == float("+inf") + await backend.sleep(1) - with pytest.raises(ZeroDivisionError): - await task.join() + assert not scope.cancelled_caught() - async def test____spawn_task____with_context( + await event_loop.create_task(coroutine()) + + async def test____open_cancel_scope____invalid_deadline( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - async def coroutine(value: str) -> None: - cvar_for_test.set(value) + with pytest.raises(ValueError): + _ = backend.open_cancel_scope(deadline=float("nan")) + + async def test____open_cancel_scope____context_reuse( + self, + backend: AsyncIOBackend, + ) -> None: + with backend.open_cancel_scope() as scope: + with pytest.raises(RuntimeError, match=r"^CancelScope entered twice$"): + with scope: + ... + + with pytest.raises(RuntimeError, match=r"^CancelScope entered twice$"): + with scope: + ... + + async def test____open_cancel_scope____context_exit_before_enter( + self, + backend: AsyncIOBackend, + ) -> None: + with pytest.raises(RuntimeError, match=r"^This cancel scope is not active$"), ExitStack() as stack: + stack.push(backend.open_cancel_scope()) + + async def test____open_cancel_scope____task_misnesting( + self, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> ExitStack: + stack = ExitStack() + stack.enter_context(backend.open_cancel_scope()) + return stack - cvar_for_test.set("something") - ctx = contextvars.copy_context() - task = backend.spawn_task(coroutine, "other", context=ctx) - await task.wait() - assert cvar_for_test.get() == "something" - assert ctx.run(cvar_for_test.get) == "other" + stack = await asyncio.create_task(coroutine()) + with pytest.raises(RuntimeError, match=r"^Attempted to exit cancel scope in a different task than it was entered in$"): + stack.close() + + async def test____open_cancel_scope____scope_misnesting( + self, + backend: AsyncIOBackend, + ) -> None: + stack = ExitStack() + stack.enter_context(backend.open_cancel_scope()) + with backend.open_cancel_scope(): + with pytest.raises( + RuntimeError, match=r"^Attempted to exit a cancel scope that isn't the current tasks's current cancel scope$" + ): + stack.close() + stack.pop_all() async def test____create_task_group____task_pool( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine(value: int) -> int: return await asyncio.sleep(0.5, value) @@ -356,7 +446,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____task_cancellation( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine(value: int) -> int: return await asyncio.sleep(0.5, value) @@ -382,10 +472,12 @@ async def coroutine(value: int) -> int: # Tasks cannot be cancelled twice assert not task_42.cancel() + @pytest.mark.parametrize("join_method", ["join", "join_or_cancel"]) async def test____create_task_group____task_join_cancel_shielding( self, + join_method: Literal["join", "join_or_cancel"], event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine(value: int) -> int: return await asyncio.sleep(0.5, value) @@ -393,19 +485,28 @@ async def coroutine(value: int) -> int: async with backend.create_task_group() as task_group: inner_task = task_group.start_soon(coroutine, 42) - outer_task = event_loop.create_task(inner_task.join()) + match join_method: + case "join": + outer_task = event_loop.create_task(inner_task.join()) + case "join_or_cancel": + outer_task = event_loop.create_task(inner_task.join_or_cancel()) + case _: + pytest.fail("invalid argument") event_loop.call_later(0.2, outer_task.cancel) with pytest.raises(asyncio.CancelledError): await outer_task assert outer_task.cancelled() - assert not inner_task.cancelled() - assert await inner_task.join() == 42 + if join_method == "join_or_cancel": + assert inner_task.cancelled() + else: + assert not inner_task.cancelled() + assert await inner_task.join() == 42 async def test____create_task_group____start_soon_with_context( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def coroutine(value: str) -> None: cvar_for_test.set(value) @@ -421,7 +522,7 @@ async def coroutine(value: str) -> None: async def test____wait_future____wait_until_done( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() event_loop.call_later(0.5, future.set_result, 42) @@ -433,7 +534,7 @@ async def test____wait_future____cancel_future_if_task_is_cancelled( self, future_running: str | None, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() if future_running == "before": @@ -464,7 +565,7 @@ async def test____wait_future____cancel_future_if_task_is_cancelled( async def test____wait_future____future_is_cancelled( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() task = event_loop.create_task(backend.wait_future(future)) @@ -477,7 +578,7 @@ async def test____wait_future____future_is_cancelled( async def test____wait_future____already_done( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() future.set_result(42) @@ -486,7 +587,7 @@ async def test____wait_future____already_done( async def test____wait_future____already_cancelled( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() future.cancel() @@ -497,7 +598,7 @@ async def test____wait_future____already_cancelled( async def test____wait_future____already_cancelled____task_cancelled_too( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() future.cancel() @@ -511,7 +612,7 @@ async def test____wait_future____already_cancelled____task_cancelled_too( async def test____wait_future____task_cancellation_prevails_over_future_cancellation( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: future: Future[int] = Future() @@ -530,7 +631,7 @@ async def test____wait_future____task_cancellation_prevails_over_future_cancella async def test____run_in_thread____cannot_be_cancelled( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: task = event_loop.create_task(backend.run_in_thread(time.sleep, 0.5)) event_loop.call_later(0.1, task.cancel) @@ -543,7 +644,7 @@ async def test____run_in_thread____cannot_be_cancelled( assert not task.cancelled() @pytest.mark.feature_sniffio - async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncioBackend) -> None: + async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncIOBackend) -> None: import sniffio sniffio.current_async_library_cvar.set("asyncio") @@ -560,7 +661,7 @@ def callback() -> str | None: async def test____create_threads_portal____run_coroutine_from_thread( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: threads_portal = backend.create_threads_portal() @@ -568,23 +669,28 @@ async def coroutine(value: int) -> int: assert asyncio.get_running_loop() is event_loop return await asyncio.sleep(0.5, value) - with pytest.raises(RuntimeError): - threads_portal.run_coroutine(coroutine, 42) - def thread() -> int: with pytest.raises(RuntimeError): asyncio.get_running_loop() return threads_portal.run_coroutine(coroutine, 42) - assert await backend.run_in_thread(thread) == 42 + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async with threads_portal: + with pytest.raises(RuntimeError): + threads_portal.run_coroutine(coroutine, 42) + + assert await backend.run_in_thread(thread) == 42 + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) async def test____create_threads_portal____run_coroutine_from_thread____can_be_called_from_other_event_loop( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() - async def coroutine(value: int) -> int: assert asyncio.get_running_loop() is event_loop return await asyncio.sleep(0.5, value) @@ -596,13 +702,13 @@ async def main() -> int: return backend.bootstrap(main) - assert await backend.run_in_thread(thread) == 42 + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 42 async def test____create_threads_portal____run_coroutine_from_thread____exception_raised( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() expected_exception = OSError("Why not?") async def coroutine(value: int) -> int: @@ -611,17 +717,16 @@ async def coroutine(value: int) -> int: def thread() -> int: return threads_portal.run_coroutine(coroutine, 42) - with pytest.raises(OSError) as exc_info: - await backend.run_in_thread(thread) + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(OSError) as exc_info: + await backend.run_in_thread(thread) assert exc_info.value is expected_exception async def test____create_threads_portal____run_coroutine_from_thread____coroutine_cancelled( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() - async def coroutine(value: int) -> int: task = asyncio.current_task() assert task is not None @@ -632,33 +737,32 @@ async def coroutine(value: int) -> int: def thread() -> int: return threads_portal.run_coroutine(coroutine, 42) - with pytest.raises(asyncio.CancelledError): - await backend.run_in_thread(thread) + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(asyncio.CancelledError): # asyncio do the job to convert the concurrent.futures.CancelledError + await backend.run_in_thread(thread) async def test____create_threads_portal____run_coroutine_from_thread____explicit_concurrent_future_Cancelled( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() - async def coroutine(value: int) -> int: raise FutureCancelledError() def thread() -> int: - with pytest.raises(asyncio.CancelledError): # asyncio do the job to convert the concurrent.futures.CancelledError + with pytest.raises(FutureCancelledError): return threads_portal.run_coroutine(coroutine, 42) return 54 - assert await backend.run_in_thread(thread) == 54 + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 54 @pytest.mark.feature_sniffio async def test____create_threads_portal____run_coroutine_from_thread____sniffio_contextvar_reset( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: import sniffio - threads_portal = backend.create_threads_portal() sniffio.current_async_library_cvar.set("main") async def coroutine() -> str | None: @@ -667,8 +771,9 @@ async def coroutine() -> str | None: def thread() -> str | None: return threads_portal.run_coroutine(coroutine) - cvar_inner = await backend.run_in_thread(thread) - cvar_outer = sniffio.current_async_library_cvar.get() + async with backend.create_threads_portal() as threads_portal: + cvar_inner = await backend.run_in_thread(thread) + cvar_outer = sniffio.current_async_library_cvar.get() assert cvar_inner == "asyncio" assert cvar_outer == "main" @@ -676,7 +781,7 @@ def thread() -> str | None: async def test____create_threads_portal____run_sync_from_thread_in_event_loop( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: threads_portal = backend.create_threads_portal() @@ -684,23 +789,28 @@ def not_threadsafe_func(value: int) -> int: assert asyncio.get_running_loop() is event_loop return value - with pytest.raises(RuntimeError): - threads_portal.run_sync(not_threadsafe_func, 42) - def thread() -> int: with pytest.raises(RuntimeError): asyncio.get_running_loop() return threads_portal.run_sync(not_threadsafe_func, 42) - assert await backend.run_in_thread(thread) == 42 + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async with threads_portal: + with pytest.raises(RuntimeError): + threads_portal.run_sync(not_threadsafe_func, 42) + + assert await backend.run_in_thread(thread) == 42 + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) async def test____create_threads_portal____run_sync_from_thread_in_event_loop____can_be_called_from_other_event_loop( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() - def not_threadsafe_func(value: int) -> int: assert asyncio.get_running_loop() is event_loop return value @@ -712,13 +822,13 @@ async def main() -> int: return backend.bootstrap(main) - assert await backend.run_in_thread(thread) == 42 + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 42 async def test____create_threads_portal____run_sync_from_thread_in_event_loop____exception_raised( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() expected_exception = OSError("Why not?") def not_threadsafe_func(value: int) -> int: @@ -727,17 +837,16 @@ def not_threadsafe_func(value: int) -> int: def thread() -> int: return threads_portal.run_sync(not_threadsafe_func, 42) - with pytest.raises(OSError) as exc_info: - await backend.run_in_thread(thread) + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(OSError) as exc_info: + await backend.run_in_thread(thread) assert exc_info.value is expected_exception async def test____create_threads_portal____run_sync_from_thread_in_event_loop____explicit_concurrent_future_Cancelled( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: - threads_portal = backend.create_threads_portal() - def not_threadsafe_func(value: int) -> int: raise FutureCancelledError() @@ -746,16 +855,32 @@ def thread() -> int: return threads_portal.run_sync(not_threadsafe_func, 42) return 54 - assert await backend.run_in_thread(thread) == 54 + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 54 + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____async_function_given( + self, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + raise AssertionError("Should not be called") + + def thread() -> None: + with pytest.raises(TypeError, match=r"^func is a coroutine function.$") as exc_info: + _ = threads_portal.run_sync(coroutine) + + assert exc_info.value.__notes__ == ["You should use run_coroutine() or run_coroutine_soon() instead."] + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) @pytest.mark.feature_sniffio async def test____create_threads_portal____run_sync_from_thread_in_event_loop____sniffio_contextvar_reset( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: import sniffio - threads_portal = backend.create_threads_portal() sniffio.current_async_library_cvar.set("main") def callback() -> str | None: @@ -764,27 +889,155 @@ def callback() -> str | None: def thread() -> str | None: return threads_portal.run_sync(callback) - cvar_inner = await backend.run_in_thread(thread) - cvar_outer = sniffio.current_async_library_cvar.get() + async with backend.create_threads_portal() as threads_portal: + cvar_inner = await backend.run_in_thread(thread) + cvar_outer = sniffio.current_async_library_cvar.get() assert cvar_inner == "asyncio" assert cvar_outer == "main" + async def test____create_threads_portal____run_sync_soon____future_cancelled_before_call( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + func_stub = mocker.stub() + + def thread() -> None: + event_loop.call_soon_threadsafe(time.sleep, 1) # Drastically slow down event loop + + future = threads_portal.run_sync_soon(func_stub, 42) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + wait_concurrent_futures({future}, timeout=5) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + func_stub.assert_not_called() + + async def test____create_threads_portal____run_coroutine_soon____future_cancelled( + self, + backend: AsyncIOBackend, + ) -> None: + def thread() -> None: + future = threads_portal.run_coroutine_soon(asyncio.sleep, 1) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + wait_concurrent_futures({future}, timeout=0.2) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + @pytest.mark.parametrize("value", [42, ValueError("Not caught")], ids=repr) + async def test____create_threads_portal____run_coroutine_soon____future_cancelled____cancellation_ignored( + self, + value: int | Exception, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + cancellation_ignored = mocker.stub() + + async def coroutine() -> int: + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + pass + await asyncio.sleep(0) + cancellation_ignored() + if isinstance(value, Exception): + raise value + return value + + def thread() -> None: + future = threads_portal.run_coroutine_soon(coroutine) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + wait_concurrent_futures({future}, timeout=0.2) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + cancellation_ignored.assert_called_once() + + async def test____create_threads_portal____context_exit____wait_scheduled_call_soon( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + func_stub = mocker.stub() + func_stub.return_value = "Hello, world!" + + def thread() -> None: + future = threads_portal.run_sync_soon(func_stub, 42) + assert future.result(timeout=1) == "Hello, world!" + + async with backend.create_threads_portal() as threads_portal: + task = event_loop.create_task(backend.run_in_thread(thread)) + event_loop.call_soon(time.sleep, 0.5) + await asyncio.sleep(0) + + func_stub.assert_called_once_with(42) + await task + + async def test____create_threads_portal____context_exit____wait_scheduled_call_soon_for_coroutine( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + coro_stub: AsyncMock = mocker.async_stub() + coro_stub.return_value = "Hello, world!" + + def thread() -> None: + future = threads_portal.run_coroutine_soon(coro_stub, 42) + assert future.result(timeout=1) == "Hello, world!" + + async with backend.create_threads_portal() as threads_portal: + task = event_loop.create_task(backend.run_in_thread(thread)) + event_loop.call_soon(time.sleep, 0.5) + await asyncio.sleep(0) + + coro_stub.assert_awaited_once_with(42) + await task + + async def test____create_threads_portal____entered_twice( + self, + backend: AsyncIOBackend, + ) -> None: + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(RuntimeError, match=r"ThreadsPortal entered twice\."): + await threads_portal.__aenter__() + @pytest.mark.asyncio class TestAsyncioBackendShieldedCancellation: @pytest.fixture @staticmethod - def backend() -> AsyncioBackend: + def backend() -> AsyncIOBackend: backend = AsyncBackendFactory.new("asyncio") - assert isinstance(backend, AsyncioBackend) + assert isinstance(backend, AsyncIOBackend) return backend @pytest.fixture(params=["cancel_shielded_coro_yield", "ignore_cancellation", "run_in_thread", "wait_future"]) @staticmethod def cancel_shielded_coroutine( request: pytest.FixtureRequest, - backend: AsyncioBackend, + backend: AsyncIOBackend, event_loop: asyncio.AbstractEventLoop, ) -> Callable[[], Awaitable[Any]]: match getattr(request, "param"): @@ -810,16 +1063,17 @@ async def test____cancel_shielded_coroutine____do_not_cancel_at_timeout_end( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: checkpoints: list[str] = [] async def coroutine(value: int) -> int: - async with backend.timeout(0) as handle: + with backend.timeout(0) as scope: await cancel_shielded_coroutine() checkpoints.append("cancel_shielded_coroutine") - assert handle.expired() # Context manager did not raise but effectively tried to cancel the task + assert scope.cancel_called() + assert not scope.cancelled_caught() await backend.coro_yield() checkpoints.append("coro_yield") return value @@ -833,7 +1087,7 @@ async def test____cancel_shielded_coroutine____cancel_at_timeout_end_if_nested( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: checkpoints: list[str] = [] @@ -841,9 +1095,9 @@ async def coroutine(value: int) -> int: current_task = asyncio.current_task() assert current_task is not None - async with backend.move_on_after(0) as handle: - async with backend.timeout(0): - async with backend.timeout(0): + with backend.move_on_after(0) as scope: + with backend.timeout(0): + with backend.timeout(0): await cancel_shielded_coroutine() checkpoints.append("inner_cancel_shielded_coroutine") assert current_task.cancelling() == 3 @@ -859,7 +1113,8 @@ async def coroutine(value: int) -> int: assert current_task.cancelling() == 0 - assert handle.expired() + assert scope.cancel_called() + assert scope.cancelled_caught() await backend.coro_yield() checkpoints.append("coro_yield") return value @@ -873,7 +1128,7 @@ async def test____timeout____cancel_at_timeout_end_if_task_cancellation_were_alr self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: checkpoints: list[str] = [] @@ -884,7 +1139,7 @@ async def coroutine(value: int) -> int: await cancel_shielded_coroutine() checkpoints.append("cancel_shielded_coroutine") - async with backend.timeout(0): + with backend.timeout(0): await cancel_shielded_coroutine() checkpoints.append("inner_cancel_shielded_coroutine") assert current_task.cancelling() == 2 @@ -901,10 +1156,93 @@ async def coroutine(value: int) -> int: await task assert checkpoints == ["cancel_shielded_coroutine", "inner_cancel_shielded_coroutine"] + async def test____cancel_shielded_coroutine____cancel_at_timeout_end_if_task_cancellation_does_not_come_from_scope( + self, + cancel_shielded_coroutine: Callable[[], Awaitable[Any]], + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + checkpoints: list[str] = [] + + async def coroutine(value: int) -> int: + current_task = asyncio.current_task() + assert current_task is not None + + with backend.open_cancel_scope(): + with backend.open_cancel_scope(): + await cancel_shielded_coroutine() + checkpoints.append("cancel_shielded_coroutine") + assert current_task.cancelling() == 1 + + assert current_task.cancelling() == 1 + await backend.coro_yield() + checkpoints.append("should_not_be_here") + return value + + task = event_loop.create_task(coroutine(42)) + event_loop.call_soon(task.cancel) + + with pytest.raises(asyncio.CancelledError): + await task + assert checkpoints == ["cancel_shielded_coroutine"] + + async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_1( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + inner_scope = backend.open_cancel_scope() + with backend.move_on_after(0.5) as outer_scope: + with inner_scope: + await backend.ignore_cancellation(backend.sleep(1)) + inner_scope.cancel() + await backend.coro_yield() + + assert outer_scope.cancel_called() + assert inner_scope.cancel_called() + + assert not outer_scope.cancelled_caught() + assert inner_scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + + async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_2( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + outer_scope = backend.move_on_after(1.5) + inner_scope = backend.move_on_after(0.5) + with outer_scope: + with inner_scope: + await backend.ignore_cancellation(backend.sleep(1)) + assert not inner_scope.cancelled_caught() + try: + await backend.coro_yield() + except asyncio.CancelledError: + pytest.fail("Cancelled") + await backend.sleep(1) + + assert outer_scope.cancel_called() + assert inner_scope.cancel_called() + + assert outer_scope.cancelled_caught() + assert not inner_scope.cancelled_caught() + + await event_loop.create_task(coroutine()) + async def test____ignore_cancellation____do_not_reschedule_if_inner_task_cancelled_itself( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: async def self_cancellation() -> None: task = asyncio.current_task() diff --git a/tests/functional_test/test_async/test_backend/test_backend_factory.py b/tests/functional_test/test_async/test_backend/test_backend_factory.py index d0c89a95..f5853650 100644 --- a/tests/functional_test/test_async/test_backend/test_backend_factory.py +++ b/tests/functional_test/test_async/test_backend/test_backend_factory.py @@ -19,18 +19,18 @@ def test____get_available_backends____imports_asyncio_backend(self) -> None: assert AsyncBackendFactory.get_available_backends() == frozenset({"asyncio"}) def test____get_all_backends____imports_asyncio_backend(self) -> None: - from easynetwork_asyncio import AsyncioBackend + from easynetwork_asyncio import AsyncIOBackend - assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncioBackend} + assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncIOBackend} def test____get_default_backend____returns_asyncio_backend(self) -> None: - from easynetwork_asyncio import AsyncioBackend + from easynetwork_asyncio import AsyncIOBackend - assert AsyncBackendFactory.get_default_backend(guess_current_async_library=False) is AsyncioBackend + assert AsyncBackendFactory.get_default_backend(guess_current_async_library=False) is AsyncIOBackend def test____new____returns_asyncio_backend_instance(self) -> None: - from easynetwork_asyncio import AsyncioBackend + from easynetwork_asyncio import AsyncIOBackend backend = AsyncBackendFactory.new("asyncio") - assert isinstance(backend, AsyncioBackend) + assert isinstance(backend, AsyncIOBackend) diff --git a/tests/functional_test/test_async/test_backend/test_futures.py b/tests/functional_test/test_async/test_backend/test_futures.py index e26eb656..84474cd5 100644 --- a/tests/functional_test/test_async/test_backend/test_futures.py +++ b/tests/functional_test/test_async/test_backend/test_futures.py @@ -4,6 +4,7 @@ import concurrent.futures import time from collections.abc import AsyncIterator +from typing import Any from easynetwork.api_async.backend.futures import AsyncExecutor @@ -68,6 +69,53 @@ def callback() -> str | None: assert cvar_inner is None assert cvar_outer == "asyncio" + async def test____map____schedule_many_calls( + self, + executor: AsyncExecutor, + ) -> None: + def thread_fn(a: int, b: int, c: int) -> tuple[int, int, int]: + return a, b, c + + results = [v async for v in executor.map(thread_fn, (1, 2, 3), (4, 5, 6), (7, 8, 9))] + + assert results == [(1, 4, 7), (2, 5, 8), (3, 6, 9)] + + async def test____map____early_schedule( + self, + executor: AsyncExecutor, + ) -> None: + def thread_fn(delay: float) -> int: + time.sleep(delay) + return 42 + + iterator = executor.map(thread_fn, (0.5, 0.75, 0.25)) + + executor.shutdown_nowait() + await asyncio.sleep(1) + + async with asyncio.timeout(0): + results = [v async for v in iterator] + + assert results == [42, 42, 42] + + @pytest.mark.feature_sniffio + async def test____map____sniffio_contextvar_reset( + self, + executor: AsyncExecutor, + ) -> None: + import sniffio + + sniffio.current_async_library_cvar.set("asyncio") + + def callback(*args: Any) -> str | None: + return sniffio.current_async_library_cvar.get() + + cvars_inner = [v async for v in executor.map(callback, (1, 2, 3))] + cvar_outer = sniffio.current_async_library_cvar.get() + + assert cvars_inner == [None, None, None] + assert cvar_outer == "asyncio" + async def test____shutdown____idempotent( self, executor: AsyncExecutor, diff --git a/tests/functional_test/test_async/test_backend/test_tasks.py b/tests/functional_test/test_async/test_backend/test_tasks.py deleted file mode 100644 index 5414ec31..00000000 --- a/tests/functional_test/test_async/test_backend/test_tasks.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, Any - -from easynetwork.api_async.backend.abc import AsyncBackend -from easynetwork.api_async.backend.factory import AsyncBackendFactory -from easynetwork.api_async.backend.tasks import SingleTaskRunner - -import pytest - -if TYPE_CHECKING: - from pytest_mock import MockerFixture - - -@pytest.mark.asyncio -class TestSingleTaskRunner: - @pytest.fixture - @staticmethod - def backend() -> AsyncBackend: - return AsyncBackendFactory.new("asyncio") - - async def test____run____run_task_once( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, return_value=mocker.sentinel.task_result) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 1, 2, 3, arg1=1, arg2=2, arg3=3) - - result1 = await runner.run() - result2 = await runner.run() - - assert result1 is mocker.sentinel.task_result - assert result2 is mocker.sentinel.task_result - coro_func.assert_awaited_once_with(1, 2, 3, arg1=1, arg2=2, arg3=3) - - async def test____run____early_cancel( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, return_value=mocker.sentinel.task_result) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 1, 2, 3, arg1=1, arg2=2, arg3=3) - - assert runner.cancel() - assert runner.cancel() # Cancel twice does nothing - - with pytest.raises(asyncio.CancelledError): - await runner.run() - - coro_func.assert_not_awaited() - coro_func.assert_not_called() - - async def test____run____cancel_while_running(self, backend: AsyncBackend) -> None: - async def coro_func(value: int) -> int: - return await asyncio.sleep(3600, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - await asyncio.sleep(0) - assert not task.done() - - assert runner.cancel() - assert runner.cancel() # Cancel twice does nothing - - with pytest.raises(asyncio.CancelledError): - await runner.run() - - assert task.cancelled() - - async def test____run____unhandled_exceptions( - self, - backend: AsyncBackend, - mocker: MockerFixture, - ) -> None: - my_exc = OSError() - coro_func = mocker.AsyncMock(spec=lambda *args, **kwargs: None, side_effect=[my_exc]) - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func) - - with pytest.raises(OSError) as exc_info_run_1: - _ = await runner.run() - with pytest.raises(OSError) as exc_info_run_2: - _ = await runner.run() - - assert exc_info_run_1.value is my_exc - assert exc_info_run_2.value is my_exc - - async def test____run____waiting_task_is_cancelled(self, backend: AsyncBackend) -> None: - inner_task: list[asyncio.Task[int] | None] = [] - - async def coro_func(value: int) -> int: - inner_task.append(asyncio.current_task()) - return await asyncio.sleep(3600, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - await asyncio.sleep(0.1) - - task.cancel() - - with pytest.raises(asyncio.CancelledError): - await task - - assert inner_task[0] is not None and inner_task[0].cancelled() - - async def test____run____waiting_task_is_cancelled____not_the_first_runner(self, backend: AsyncBackend) -> None: - async def coro_func(value: int) -> int: - return await asyncio.sleep(0.5, value) - - runner: SingleTaskRunner[Any] = SingleTaskRunner(backend, coro_func, 42) - - task = asyncio.create_task(runner.run()) - task_2 = asyncio.create_task(runner.run()) - await asyncio.sleep(0.1) - - task_2.cancel() - - assert await task == 42 - with pytest.raises(asyncio.CancelledError): - await task_2 diff --git a/tests/functional_test/test_communication/test_async/test_server/base.py b/tests/functional_test/test_communication/test_async/test_server/base.py index 6c8c90ac..738c7d6c 100644 --- a/tests/functional_test/test_communication/test_async/test_server/base.py +++ b/tests/functional_test/test_communication/test_async/test_server/base.py @@ -67,6 +67,22 @@ async def test____serve_forever____shutdown_during_setup( await server.shutdown() assert event.is_set() + async def test____serve_forever____server_close_during_setup( + self, + server: AbstractAsyncNetworkServer, + ) -> None: + event = asyncio.Event() + server_task = None + with pytest.raises(ExceptionGroup): + async with asyncio.TaskGroup() as tg: + server_task = tg.create_task(server.serve_forever(is_up_event=event)) + await asyncio.sleep(0) + assert not event.is_set() + await server.server_close() + assert not event.is_set() + assert server_task is not None + assert isinstance(server_task.exception(), ServerClosedError) + async def test____serve_forever____without_is_up_event( self, server: AbstractAsyncNetworkServer, diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 11b01b33..277f4c59 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -23,7 +23,7 @@ from easynetwork.protocol import StreamProtocol from easynetwork.tools.socket import SocketAddress, enable_socket_linger from easynetwork_asyncio._utils import create_connection -from easynetwork_asyncio.backend import AsyncioBackend +from easynetwork_asyncio.backend import AsyncIOBackend from easynetwork_asyncio.stream.listener import ListenerSocketAdapter import pytest @@ -32,7 +32,7 @@ from .base import BaseTestAsyncServer -class NoListenerErrorBackend(AsyncioBackend): +class NoListenerErrorBackend(AsyncIOBackend): async def create_tcp_listeners( self, host: str | Sequence[str] | None, @@ -167,7 +167,7 @@ async def handle(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, s await client.send_packet(request) try: with pytest.raises(TimeoutError): - async with self.backend.timeout(self.request_timeout): + with self.backend.timeout(self.request_timeout): yield await client.send_packet("successfully timed out") finally: @@ -195,7 +195,7 @@ async def on_connection(self, client: AsyncStreamClient[str]) -> AsyncGenerator[ if self.bypass_handshake: return try: - async with self.backend.timeout(1): + with self.backend.timeout(1): password = yield except TimeoutError: await client.send_packet("timeout error") diff --git a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py index 6ecd0ca1..dbdc25d8 100644 --- a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py +++ b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio import contextlib import threading import time -from collections.abc import AsyncGenerator, Callable, Iterator +from collections.abc import AsyncGenerator, Iterator from easynetwork.api_async.server.abc import AbstractAsyncNetworkServer from easynetwork.api_async.server.handler import AsyncBaseClientInterface, AsyncDatagramRequestHandler, AsyncStreamRequestHandler @@ -97,29 +96,11 @@ def test____server_thread____several_join( start_server.join() -def custom_asyncio_runner() -> asyncio.Runner: - return asyncio.Runner(loop_factory=asyncio.new_event_loop) - - class TestStandaloneTCPNetworkServer(BaseTestStandaloneNetworkServer): - @pytest.fixture(params=[None, custom_asyncio_runner]) - @staticmethod - def runner_factory(request: pytest.FixtureRequest) -> Callable[[], asyncio.Runner] | None: - return getattr(request, "param", None) - @pytest.fixture @staticmethod - def server( - stream_protocol: StreamProtocol[str, str], - runner_factory: Callable[[], asyncio.Runner] | None, - ) -> StandaloneTCPNetworkServer[str, str]: - return StandaloneTCPNetworkServer( - None, - 0, - stream_protocol, - EchoRequestHandler(), - backend_kwargs={"runner_factory": runner_factory}, - ) + def server(stream_protocol: StreamProtocol[str, str]) -> StandaloneTCPNetworkServer[str, str]: + return StandaloneTCPNetworkServer(None, 0, stream_protocol, EchoRequestHandler()) def test____dunder_init____invalid_backend(self, stream_protocol: StreamProtocol[str, str]) -> None: with pytest.raises(ValueError, match=r"^You must explicitly give a backend name or instance$"): @@ -131,7 +112,6 @@ def test____dunder_init____invalid_backend(self, stream_protocol: StreamProtocol backend=None, # type: ignore[arg-type] ) - @pytest.mark.parametrize("runner_factory", [None], indirect=True) def test____serve_forever____serve_several_times(self, server: StandaloneTCPNetworkServer[str, str]) -> None: with server: for _ in range(3): @@ -169,24 +149,10 @@ def test____logger_property____exposed(self, server: StandaloneTCPNetworkServer[ class TestStandaloneUDPNetworkServer(BaseTestStandaloneNetworkServer): - @pytest.fixture(params=[None, custom_asyncio_runner]) - @staticmethod - def runner_factory(request: pytest.FixtureRequest) -> Callable[[], asyncio.Runner] | None: - return getattr(request, "param", None) - @pytest.fixture @staticmethod - def server( - datagram_protocol: DatagramProtocol[str, str], - runner_factory: Callable[[], asyncio.Runner] | None, - ) -> StandaloneUDPNetworkServer[str, str]: - return StandaloneUDPNetworkServer( - "localhost", - 0, - datagram_protocol, - EchoRequestHandler(), - backend_kwargs={"runner_factory": runner_factory}, - ) + def server(datagram_protocol: DatagramProtocol[str, str]) -> StandaloneUDPNetworkServer[str, str]: + return StandaloneUDPNetworkServer("localhost", 0, datagram_protocol, EchoRequestHandler()) def test____dunder_init____invalid_backend(self, datagram_protocol: DatagramProtocol[str, str]) -> None: with pytest.raises(ValueError, match=r"^You must explicitly give a backend name or instance$"): @@ -198,7 +164,6 @@ def test____dunder_init____invalid_backend(self, datagram_protocol: DatagramProt backend=None, # type: ignore[arg-type] ) - @pytest.mark.parametrize("runner_factory", [None], indirect=True) def test____serve_forever____serve_several_times(self, server: StandaloneUDPNetworkServer[str, str]) -> None: with server: for _ in range(3): diff --git a/tests/unit_test/test_async/base.py b/tests/unit_test/test_async/base.py deleted file mode 100644 index f0dfa208..00000000 --- a/tests/unit_test/test_async/base.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from socket import AF_INET6 -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from unittest.mock import MagicMock - -from ..base import BaseTestSocket - - -class BaseTestAsyncSocketAdapter(BaseTestSocket): - @classmethod - def set_local_address_to_async_socket_adapter_mock( - cls, - mock_async_socket_adapter: MagicMock, - socket_family: int, - address: tuple[str, int] | None, - ) -> None: - if address is None: - if socket_family == AF_INET6: - address = ("::", 0) - else: - address = ("0.0.0.0", 0) - mock_async_socket_adapter.get_local_address.return_value = cls.get_resolved_addr_format(address, socket_family) - - @classmethod - def set_remote_address_to_async_socket_adapter_mock( - cls, - mock_async_socket_adapter: MagicMock, - socket_family: int, - address: tuple[str, int], - ) -> None: - mock_async_socket_adapter.get_remote_address.return_value = cls.get_resolved_addr_format(address, socket_family) diff --git a/tests/unit_test/test_async/conftest.py b/tests/unit_test/test_async/conftest.py index ced23fa5..8a08211c 100644 --- a/tests/unit_test/test_async/conftest.py +++ b/tests/unit_test/test_async/conftest.py @@ -4,12 +4,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from easynetwork.api_async.backend.abc import ( - AsyncBackend, - AsyncDatagramSocketAdapter, - AsyncHalfCloseableStreamSocketAdapter, - AsyncStreamSocketAdapter, -) +from easynetwork.api_async.backend.abc import AsyncBackend, AsyncDatagramSocketAdapter, AsyncStreamSocketAdapter import pytest @@ -37,14 +32,13 @@ def fake_cancellation_cls() -> type[BaseException]: @pytest.fixture def mock_backend(fake_cancellation_cls: type[BaseException], mocker: MockerFixture) -> MagicMock: - from easynetwork_asyncio.tasks import SystemTask, TaskGroup + from easynetwork_asyncio.tasks import TaskGroup from .._utils import AsyncDummyLock mock_backend = mocker.NonCallableMagicMock(spec=AsyncBackend) mock_backend.get_cancelled_exc_class.return_value = fake_cancellation_cls - mock_backend.spawn_task = lambda coro_func, *args, **kwargs: SystemTask(coro_func(*args), **kwargs) mock_backend.create_lock = AsyncDummyLock mock_backend.create_event = asyncio.Event mock_backend.create_task_group = TaskGroup @@ -53,17 +47,9 @@ def mock_backend(fake_cancellation_cls: type[BaseException], mocker: MockerFixtu @pytest.fixture -def mock_stream_socket_adapter_factory(request: pytest.FixtureRequest, mocker: MockerFixture) -> Callable[[], MagicMock]: - param = getattr(request, "param", None) - assert param in ("eof_support", None) - - eof_support: bool = param == "eof_support" - +def mock_stream_socket_adapter_factory(mocker: MockerFixture) -> Callable[[], MagicMock]: def factory() -> MagicMock: - if eof_support: - mock = mocker.NonCallableMagicMock(spec=AsyncHalfCloseableStreamSocketAdapter) - else: - mock = mocker.NonCallableMagicMock(spec=AsyncStreamSocketAdapter) + mock = mocker.NonCallableMagicMock(spec=AsyncStreamSocketAdapter) mock.sendall_fromiter = mocker.MagicMock(side_effect=lambda iterable_of_data: mock.sendall(b"".join(iterable_of_data))) mock.is_closing.return_value = False return mock diff --git a/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py b/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py index 48ee61d3..d1c85cef 100644 --- a/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py +++ b/tests/unit_test/test_async/test_api/test_backend/_fake_backends.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Coroutine, Sequence from socket import socket as Socket -from typing import Any, AsyncContextManager, NoReturn, final +from typing import Any, NoReturn, final from easynetwork.api_async.backend.abc import ( AsyncBackend, @@ -12,16 +12,13 @@ ICondition, IEvent, ILock, - Runner, - SystemTask, TaskGroup, ThreadsPortal, - TimeoutHandle, ) class BaseFakeBackend(AsyncBackend): - def new_runner(self) -> Runner: + def bootstrap(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError async def sleep(self, delay: float) -> None: @@ -42,22 +39,22 @@ async def cancel_shielded_coro_yield(self) -> None: async def ignore_cancellation(self, coroutine: Coroutine[Any, Any, Any]) -> Any: raise NotImplementedError - def timeout(self, delay: Any) -> AsyncContextManager[TimeoutHandle]: + def timeout(self, delay: Any) -> Any: raise NotImplementedError - def timeout_at(self, deadline: Any) -> AsyncContextManager[TimeoutHandle]: + def timeout_at(self, deadline: Any) -> Any: raise NotImplementedError - def move_on_after(self, delay: Any) -> AsyncContextManager[TimeoutHandle]: + def move_on_after(self, delay: Any) -> Any: raise NotImplementedError - def move_on_at(self, deadline: Any) -> AsyncContextManager[TimeoutHandle]: + def move_on_at(self, deadline: Any) -> Any: raise NotImplementedError - def get_cancelled_exc_class(self) -> type[BaseException]: + def open_cancel_scope(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError - def spawn_task(self, *args: Any, **kwargs: Any) -> SystemTask[Any]: + def get_cancelled_exc_class(self) -> type[BaseException]: raise NotImplementedError def create_task_group(self) -> TaskGroup: diff --git a/tests/unit_test/test_async/test_api/test_backend/test_backend.py b/tests/unit_test/test_async/test_api/test_backend/test_backend.py index dd2fc3df..bc5d4b67 100644 --- a/tests/unit_test/test_async/test_api/test_backend/test_backend.py +++ b/tests/unit_test/test_async/test_api/test_backend/test_backend.py @@ -102,9 +102,9 @@ def reset_factory_cache_at_end() -> Iterator[None]: yield AsyncBackendFactory.invalidate_backends_cache() - from easynetwork_asyncio import AsyncioBackend + from easynetwork_asyncio import AsyncIOBackend - assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncioBackend} + assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncIOBackend} @pytest.fixture(autouse=True) @staticmethod diff --git a/tests/unit_test/test_async/test_api/test_backend/test_futures.py b/tests/unit_test/test_async/test_api/test_backend/test_futures.py index 1b79567f..0ba0124b 100644 --- a/tests/unit_test/test_async/test_api/test_backend/test_futures.py +++ b/tests/unit_test/test_async/test_api/test_backend/test_futures.py @@ -36,22 +36,10 @@ def executor_handle_contexts(request: pytest.FixtureRequest) -> bool: def executor(mock_backend: MagicMock, mock_stdlib_executor: MagicMock, executor_handle_contexts: bool) -> AsyncExecutor: return AsyncExecutor(mock_stdlib_executor, mock_backend, handle_contexts=executor_handle_contexts) - @pytest.fixture - @staticmethod - def mock_context(mocker: MockerFixture) -> MagicMock: - return mocker.NonCallableMagicMock(spec=contextvars.Context) - @pytest.fixture(autouse=True) @staticmethod - def mock_contextvars_copy_context( - mock_context: MagicMock, - mocker: MockerFixture, - ) -> MagicMock: - return mocker.patch( - "contextvars.copy_context", - autospec=True, - return_value=mock_context, - ) + def mock_contextvars_copy_context(mocker: MockerFixture) -> MagicMock: + return mocker.patch("contextvars.copy_context", autospec=True) async def test___dunder_init___invalid_executor( self, @@ -71,13 +59,18 @@ async def test____run____submit_to_executor_and_wait( executor_handle_contexts: bool, mock_backend: MagicMock, mock_stdlib_executor: MagicMock, - mock_context: MagicMock, mock_contextvars_copy_context: MagicMock, mocker: MockerFixture, ) -> None: # Arrange + mock_context: MagicMock = mocker.NonCallableMagicMock(spec=contextvars.Context) + mock_contextvars_copy_context.return_value = mock_context func = mocker.stub() - mock_stdlib_executor.submit.return_value = mocker.sentinel.future + mock_future = mocker.NonCallableMagicMock( + spec=concurrent.futures.Future, + **{"cancel.return_value": False}, + ) + mock_stdlib_executor.submit.return_value = mock_future mock_backend.wait_future.return_value = mocker.sentinel.result # Act @@ -114,9 +107,75 @@ async def test____run____submit_to_executor_and_wait( kw2=mocker.sentinel.kw2, ) func.assert_not_called() - mock_backend.wait_future.assert_awaited_once_with(mocker.sentinel.future) + mock_backend.wait_future.assert_awaited_once_with(mock_future) + mock_future.cancel.assert_called_once_with() assert result is mocker.sentinel.result + @pytest.mark.parametrize("future_exception", [Exception, None]) + async def test____map____submit_to_executor_and_wait( + self, + executor: AsyncExecutor, + executor_handle_contexts: bool, + future_exception: type[BaseException] | None, + mock_backend: MagicMock, + mock_stdlib_executor: MagicMock, + mock_contextvars_copy_context: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_contexts: list[MagicMock] = [mocker.NonCallableMagicMock(spec=contextvars.Context) for _ in range(3)] + mock_futures: list[MagicMock] = [ + mocker.NonCallableMagicMock( + spec=concurrent.futures.Future, + **{"result.return_value": i, "cancel.return_value": False}, + ) + for i in range(3) + ] + mock_contextvars_copy_context.side_effect = mock_contexts + func = mocker.stub() + mock_stdlib_executor.submit.side_effect = mock_futures + if future_exception is None: + mock_backend.wait_future.side_effect = [mocker.sentinel.result_1, mocker.sentinel.result_2, mocker.sentinel.result_3] + else: + mock_backend.wait_future.side_effect = future_exception() + func_args = (mocker.sentinel.arg1, mocker.sentinel.arg2, mocker.sentinel.args3) + + # Act + if future_exception is None: + results = [result async for result in executor.map(func, func_args)] + else: + results = [] + with pytest.raises(Exception) as exc_info: + results = [result async for result in executor.map(func, func_args)] + assert exc_info.value is mock_backend.wait_future.side_effect + + # Assert + if executor_handle_contexts: + assert mock_contextvars_copy_context.call_args_list == [mocker.call() for _ in range(len(mock_contexts))] + if current_async_library_cvar is not None: + for mock_context in mock_contexts: + mock_context.run.assert_called_once_with(current_async_library_cvar.set, None) + else: + for mock_context in mock_contexts: + mock_context.run.assert_not_called() + assert mock_stdlib_executor.submit.call_args_list == [ + mocker.call(partial_eq(mock_context.run, func), arg) for mock_context, arg in zip(mock_contexts, func_args) + ] + else: + mock_contextvars_copy_context.assert_not_called() + for mock_context in mock_contexts: + mock_context.run.assert_not_called() + assert mock_stdlib_executor.submit.call_args_list == [mocker.call(func, arg) for arg in func_args] + func.assert_not_called() + if future_exception is None: + mock_backend.wait_future.await_args_list == [mocker.call(mock_fut) for mock_fut in mock_futures] + assert results == [mocker.sentinel.result_1, mocker.sentinel.result_2, mocker.sentinel.result_3] + else: + mock_backend.wait_future.await_args_list == [mocker.call(mock_futures[0])] + assert results == [] + for mock_fut in mock_futures: + mock_fut.cancel.assert_called_once_with() + async def test____shutdown_nowait____shutdown_executor( self, executor: AsyncExecutor, diff --git a/tests/unit_test/test_async/test_api/test_client/base.py b/tests/unit_test/test_async/test_api/test_client/base.py index dd1a7f24..f7ea6fe9 100644 --- a/tests/unit_test/test_async/test_api/test_client/base.py +++ b/tests/unit_test/test_async/test_api/test_client/base.py @@ -1,7 +1,7 @@ from __future__ import annotations -from ...base import BaseTestAsyncSocketAdapter +from ....base import BaseTestSocket -class BaseTestClient(BaseTestAsyncSocketAdapter): +class BaseTestClient(BaseTestSocket): pass diff --git a/tests/unit_test/test_async/test_api/test_client/test_tcp.py b/tests/unit_test/test_async/test_api/test_client/test_tcp.py index a1ed94c7..9c6a8e2a 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_tcp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_tcp.py @@ -69,12 +69,12 @@ def mock_new_backend(mocker: MockerFixture, mock_backend: MagicMock) -> MagicMoc @classmethod def local_address( cls, - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, socket_family: int, global_local_address: tuple[str, int], ) -> tuple[str, int]: - cls.set_local_address_to_async_socket_adapter_mock( - mock_stream_socket_adapter, + cls.set_local_address_to_socket_mock( + mock_tcp_socket, socket_family, global_local_address, ) @@ -84,12 +84,12 @@ def local_address( @classmethod def remote_address( cls, - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, socket_family: int, global_remote_address: tuple[str, int], ) -> tuple[str, int]: - cls.set_remote_address_to_async_socket_adapter_mock( - mock_stream_socket_adapter, + cls.set_remote_address_to_socket_mock( + mock_tcp_socket, socket_family, global_remote_address, ) @@ -105,8 +105,6 @@ def set_default_socket_mock_configuration( ) -> None: mock_tcp_socket.family = socket_family mock_tcp_socket.getsockopt.return_value = 0 # Needed for tests dealing with send_packet() - del mock_tcp_socket.getsockname - del mock_tcp_socket.getpeername del mock_tcp_socket.sendall del mock_tcp_socket.recv @@ -208,8 +206,8 @@ async def test____dunder_init____connect_to_remote( happy_eyeballs_delay=mocker.sentinel.happy_eyeballs_delay, ) mock_stream_socket_adapter.socket.assert_called_once_with() - mock_stream_socket_adapter.get_local_address.assert_called_once_with() - mock_stream_socket_adapter.get_remote_address.assert_called_once_with() + mock_tcp_socket.getsockname.assert_called_once_with() + mock_tcp_socket.getpeername.assert_called_once_with() assert mock_tcp_socket.setsockopt.mock_calls == [ mocker.call(IPPROTO_TCP, TCP_NODELAY, True), mocker.call(SOL_SOCKET, SO_KEEPALIVE, True), @@ -275,8 +273,8 @@ async def test____dunder_init____use_given_socket( mock_stream_data_consumer_cls.assert_called_once_with(mock_stream_protocol) mock_backend.wrap_tcp_client_socket.assert_awaited_once_with(mock_tcp_socket) mock_stream_socket_adapter.socket.assert_called_once_with() - mock_stream_socket_adapter.get_local_address.assert_called_once_with() - mock_stream_socket_adapter.get_remote_address.assert_called_once_with() + mock_tcp_socket.getsockname.assert_called_once_with() + mock_tcp_socket.getpeername.assert_called_once_with() assert mock_tcp_socket.setsockopt.mock_calls == [ mocker.call(IPPROTO_TCP, TCP_NODELAY, True), mocker.call(SOL_SOCKET, SO_KEEPALIVE, True), @@ -451,8 +449,8 @@ async def test____dunder_init____ssl( happy_eyeballs_delay=mocker.sentinel.happy_eyeballs_delay, ) mock_stream_socket_adapter.socket.assert_called_once_with() - mock_stream_socket_adapter.get_local_address.assert_called_once_with() - mock_stream_socket_adapter.get_remote_address.assert_called_once_with() + mock_tcp_socket.getsockname.assert_called_once_with() + mock_tcp_socket.getpeername.assert_called_once_with() assert mock_tcp_socket.setsockopt.mock_calls == [ mocker.call(IPPROTO_TCP, TCP_NODELAY, True), mocker.call(SOL_SOCKET, SO_KEEPALIVE, True), @@ -886,10 +884,10 @@ async def test____get_local_address____return_saved_address( client_closed: bool, socket_family: int, local_address: tuple[str, int], - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, ) -> None: # Arrange - mock_stream_socket_adapter.get_local_address.reset_mock() + mock_tcp_socket.getsockname.reset_mock() if client_closed: await client_connected.aclose() assert client_connected.is_closing() @@ -902,14 +900,14 @@ async def test____get_local_address____return_saved_address( assert isinstance(address, IPv6SocketAddress) else: assert isinstance(address, IPv4SocketAddress) - mock_stream_socket_adapter.get_local_address.assert_not_called() + mock_tcp_socket.getsockname.assert_not_called() assert address.host == local_address[0] assert address.port == local_address[1] async def test____get_local_address____error_connection_not_performed( self, client_not_connected: AsyncTCPNetworkClient[Any, Any], - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, ) -> None: # Arrange @@ -918,7 +916,7 @@ async def test____get_local_address____error_connection_not_performed( client_not_connected.get_local_address() # Assert - mock_stream_socket_adapter.get_local_address.assert_not_called() + mock_tcp_socket.getsockname.assert_not_called() @pytest.mark.parametrize("client_closed", [False, True], ids=lambda p: f"client_closed=={p}") async def test____get_remote_address____return_saved_address( @@ -927,11 +925,10 @@ async def test____get_remote_address____return_saved_address( client_closed: bool, remote_address: tuple[str, int], socket_family: int, - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, ) -> None: # Arrange - ## NOTE: The client should have the remote address saved. Therefore this test check if there is no new call. - mock_stream_socket_adapter.get_remote_address.assert_called_once() + mock_tcp_socket.getpeername.reset_mock() if client_closed: await client_connected.aclose() assert client_connected.is_closing() @@ -944,14 +941,14 @@ async def test____get_remote_address____return_saved_address( assert isinstance(address, IPv6SocketAddress) else: assert isinstance(address, IPv4SocketAddress) - mock_stream_socket_adapter.get_remote_address.assert_called_once() + mock_tcp_socket.getpeername.assert_not_called() assert address.host == remote_address[0] assert address.port == remote_address[1] async def test____get_remote_address____error_connection_not_performed( self, client_not_connected: AsyncTCPNetworkClient[Any, Any], - mock_stream_socket_adapter: MagicMock, + mock_tcp_socket: MagicMock, ) -> None: # Arrange @@ -960,7 +957,7 @@ async def test____get_remote_address____error_connection_not_performed( client_not_connected.get_remote_address() # Assert - mock_stream_socket_adapter.get_remote_address.assert_not_called() + mock_tcp_socket.getpeername.assert_not_called() @pytest.mark.usefixtures("setup_producer_mock") async def test____send_packet____send_bytes_to_socket( @@ -1101,25 +1098,18 @@ async def test____send_packet____convert_closed_socket_error( mock_tcp_socket.getsockopt.assert_not_called() @pytest.mark.usefixtures("setup_producer_mock") - @pytest.mark.parametrize("error", [OSError, None]) - @pytest.mark.parametrize("mock_stream_socket_adapter_factory", ["eof_support"], indirect=True) async def test____send_eof____socket_send_eof( self, client_connected_or_not: AsyncTCPNetworkClient[Any, Any], - error: type[BaseException] | None, mock_stream_socket_adapter: MagicMock, mock_stream_protocol: MagicMock, mocker: MockerFixture, ) -> None: # Arrange - if error is None: - mock_stream_socket_adapter.send_eof.return_value = None - else: - mock_stream_socket_adapter.send_eof.side_effect = error + mock_stream_socket_adapter.send_eof.return_value = None # Act - with pytest.raises(error) if error is not None else contextlib.nullcontext(): - await client_connected_or_not.send_eof() + await client_connected_or_not.send_eof() # Assert mock_stream_socket_adapter.send_eof.assert_awaited_once_with() @@ -1130,26 +1120,22 @@ async def test____send_eof____socket_send_eof( mock_stream_socket_adapter.sendall.assert_not_awaited() @pytest.mark.usefixtures("setup_producer_mock") - @pytest.mark.parametrize("mock_stream_socket_adapter_factory", ["eof_support", None], indirect=True) async def test____send_eof____closed_client( self, client_connected_or_not: AsyncTCPNetworkClient[Any, Any], mock_stream_socket_adapter: MagicMock, ) -> None: # Arrange - if hasattr(mock_stream_socket_adapter, "send_eof"): - mock_stream_socket_adapter.send_eof.return_value = None + mock_stream_socket_adapter.send_eof.return_value = None await client_connected_or_not.aclose() # Act await client_connected_or_not.send_eof() # Assert - if hasattr(mock_stream_socket_adapter, "send_eof"): - mock_stream_socket_adapter.send_eof.assert_not_awaited() + mock_stream_socket_adapter.send_eof.assert_not_awaited() @pytest.mark.usefixtures("setup_producer_mock") - @pytest.mark.parametrize("mock_stream_socket_adapter_factory", ["eof_support"], indirect=True) async def test____send_eof____unexpected_socket_close( self, client_connected_or_not: AsyncTCPNetworkClient[Any, Any], @@ -1166,7 +1152,6 @@ async def test____send_eof____unexpected_socket_close( mock_stream_socket_adapter.send_eof.assert_not_awaited() @pytest.mark.usefixtures("setup_producer_mock") - @pytest.mark.parametrize("mock_stream_socket_adapter_factory", ["eof_support"], indirect=True) async def test____send_eof____idempotent( self, client_connected_or_not: AsyncTCPNetworkClient[Any, Any], @@ -1182,27 +1167,6 @@ async def test____send_eof____idempotent( # Assert mock_stream_socket_adapter.send_eof.assert_awaited_once() - @pytest.mark.usefixtures("setup_producer_mock") - async def test____send_eof____not_supported( - self, - client_connected_or_not: AsyncTCPNetworkClient[Any, Any], - mock_stream_socket_adapter: MagicMock, - mock_stream_protocol: MagicMock, - mocker: MockerFixture, - ) -> None: - # Arrange - assert not hasattr(mock_stream_socket_adapter, "send_eof") - - # Act - with pytest.raises(NotImplementedError): - await client_connected_or_not.send_eof() - - # Assert - await client_connected_or_not.send_packet(mocker.sentinel.packet) - mock_stream_protocol.generate_chunks.assert_called() - mock_stream_socket_adapter.sendall_fromiter.assert_called() - mock_stream_socket_adapter.sendall.assert_awaited() - @pytest.mark.usefixtures("setup_consumer_mock") async def test____recv_packet____receive_bytes_from_socket( self, diff --git a/tests/unit_test/test_async/test_api/test_client/test_udp.py b/tests/unit_test/test_async/test_api/test_client/test_udp.py index 3815dd7e..f72e0377 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_udp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_udp.py @@ -51,17 +51,11 @@ def global_remote_address() -> tuple[str, int]: @classmethod def local_address( cls, - mock_datagram_socket_adapter: MagicMock, mock_udp_socket: MagicMock, socket_family: int, global_local_address: tuple[str, int], ) -> tuple[str, int]: cls.set_local_address_to_socket_mock(mock_udp_socket, socket_family, global_local_address) - cls.set_local_address_to_async_socket_adapter_mock( - mock_datagram_socket_adapter, - socket_family, - global_local_address, - ) return global_local_address @pytest.fixture(autouse=True, params=[False, True], ids=lambda p: f"remote_address=={p}") @@ -69,20 +63,16 @@ def local_address( def remote_address( cls, request: Any, - mock_datagram_socket_adapter: MagicMock, + mock_udp_socket: MagicMock, socket_family: int, global_remote_address: tuple[str, int], ) -> tuple[str, int] | None: match request.param: case True: - cls.set_remote_address_to_async_socket_adapter_mock( - mock_datagram_socket_adapter, - socket_family, - global_remote_address, - ) + cls.set_remote_address_to_socket_mock(mock_udp_socket, socket_family, global_remote_address) return global_remote_address case False: - mock_datagram_socket_adapter.get_remote_address.return_value = None + cls.configure_socket_mock_to_raise_ENOTCONN(mock_udp_socket) return None case invalid: pytest.fail(f"Invalid fixture param: Got {invalid!r}") @@ -97,7 +87,6 @@ def set_default_socket_mock_configuration( ) -> None: mock_udp_socket.family = socket_family mock_udp_socket.getsockopt.return_value = 0 # Needed for tests dealing with send_packet_to() - del mock_udp_socket.getpeername del mock_udp_socket.send del mock_udp_socket.sendto del mock_udp_socket.recvfrom @@ -161,6 +150,7 @@ def sender_address(request: Any, global_remote_address: tuple[str, int]) -> tupl async def test____dunder_init____with_remote_address( self, remote_address: tuple[str, int] | None, + mock_udp_socket: MagicMock, mock_datagram_socket_adapter: MagicMock, mock_datagram_protocol: MagicMock, mock_new_backend: MagicMock, @@ -186,8 +176,8 @@ async def test____dunder_init____with_remote_address( reuse_port=mocker.sentinel.reuse_port, ) mock_datagram_socket_adapter.socket.assert_called_once_with() - mock_datagram_socket_adapter.get_local_address.assert_called_once_with() - mock_datagram_socket_adapter.get_remote_address.assert_called_once_with() + mock_udp_socket.getsockname.assert_called_once_with() + mock_udp_socket.getpeername.assert_called_once_with() assert isinstance(client.socket, SocketProxy) async def test____dunder_init____with_remote_address____force_local_address( @@ -257,6 +247,7 @@ async def test____dunder_init____use_given_socket( mock_datagram_protocol: MagicMock, mock_new_backend: MagicMock, mock_backend: MagicMock, + mocker: MockerFixture, ) -> None: # Arrange @@ -272,8 +263,8 @@ async def test____dunder_init____use_given_socket( mock_new_backend.assert_called_once_with(None) mock_backend.wrap_udp_socket.assert_awaited_once_with(mock_udp_socket) mock_datagram_socket_adapter.socket.assert_called_once_with() - mock_datagram_socket_adapter.get_local_address.assert_called_once_with() - mock_datagram_socket_adapter.get_remote_address.assert_called_once_with() + assert mock_udp_socket.getsockname.call_args_list == [mocker.call() for _ in range(2)] + mock_udp_socket.getpeername.assert_called_once_with() assert isinstance(client.socket, SocketProxy) async def test____dunder_init____use_given_socket____force_local_address( @@ -464,10 +455,10 @@ async def test____get_local_address____return_saved_address( socket_family: int, local_address: tuple[str, int], client_bound: AsyncUDPNetworkClient[Any, Any], - mock_datagram_socket_adapter: MagicMock, + mock_udp_socket: MagicMock, ) -> None: # Arrange - mock_datagram_socket_adapter.get_local_address.reset_mock() + mock_udp_socket.getsockname.reset_mock() if client_closed: await client_bound.aclose() assert client_bound.is_closing() @@ -480,23 +471,24 @@ async def test____get_local_address____return_saved_address( assert isinstance(address, IPv6SocketAddress) else: assert isinstance(address, IPv4SocketAddress) - mock_datagram_socket_adapter.get_local_address.assert_not_called() + mock_udp_socket.getsockname.assert_not_called() assert address.host == local_address[0] assert address.port == local_address[1] async def test____get_local_address____error_connection_not_performed( self, client_not_bound: AsyncUDPNetworkEndpoint[Any, Any], - mock_datagram_socket_adapter: MagicMock, + mock_udp_socket: MagicMock, ) -> None: # Arrange + mock_udp_socket.getsockname.reset_mock() # Act with pytest.raises(OSError): client_not_bound.get_local_address() # Assert - mock_datagram_socket_adapter.get_local_address.assert_not_called() + mock_udp_socket.getsockname.assert_not_called() @pytest.mark.parametrize("client_closed", [False, True], ids=lambda p: f"client_closed=={p}") async def test____get_remote_address____return_saved_address( @@ -505,11 +497,10 @@ async def test____get_remote_address____return_saved_address( remote_address: tuple[str, int] | None, socket_family: int, client_bound: AsyncUDPNetworkClient[Any, Any], - mock_datagram_socket_adapter: MagicMock, + mock_udp_socket: MagicMock, ) -> None: # Arrange - ## NOTE: The client should have the remote address saved. Therefore this test check if there is no new call. - mock_datagram_socket_adapter.get_remote_address.assert_called_once() + mock_udp_socket.getpeername.reset_mock() if client_closed: await client_bound.aclose() assert client_bound.is_closing() @@ -527,12 +518,12 @@ async def test____get_remote_address____return_saved_address( assert isinstance(address, IPv4SocketAddress) assert address.host == remote_address[0] assert address.port == remote_address[1] - mock_datagram_socket_adapter.get_remote_address.assert_called_once() + mock_udp_socket.getpeername.assert_not_called() async def test____get_remote_address____error_connection_not_performed( self, client_not_bound: AsyncUDPNetworkEndpoint[Any, Any], - mock_datagram_socket_adapter: MagicMock, + mock_udp_socket: MagicMock, ) -> None: # Arrange @@ -541,7 +532,7 @@ async def test____get_remote_address____error_connection_not_performed( client_not_bound.get_remote_address() # Assert - mock_datagram_socket_adapter.get_remote_address.assert_not_called() + mock_udp_socket.getpeername.assert_not_called() @pytest.mark.parametrize("remote_address", [False], indirect=True) @pytest.mark.usefixtures("setup_protocol_mock") diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py index 18d2e19d..2fc12836 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py @@ -3,12 +3,12 @@ import asyncio import contextlib import contextvars -from collections.abc import Callable, Sequence +from collections.abc import Callable, Coroutine, Sequence from socket import AF_INET from typing import TYPE_CHECKING, Any, cast from easynetwork.api_async.backend.abc import AsyncStreamSocketAdapter -from easynetwork_asyncio import AsyncioBackend +from easynetwork_asyncio import AsyncIOBackend import pytest @@ -20,6 +20,56 @@ from pytest_mock import MockerFixture +class TestAsyncIOBackendSync: + @pytest.fixture + @staticmethod + def backend() -> AsyncIOBackend: + return AsyncIOBackend() + + @pytest.mark.parametrize("runner_options", [{"loop_factory": 42}, None]) + def test____bootstrap____start_new_runner( + self, + runner_options: dict[str, Any] | None, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_asyncio_runner: MagicMock = mocker.NonCallableMagicMock( + spec=asyncio.Runner, + **{"run.return_value": mocker.sentinel.Runner_ret_val}, + ) + mock_asyncio_runner.__enter__.return_value = mock_asyncio_runner + mock_asyncio_runner_cls = mocker.patch("asyncio.Runner", side_effect=[mock_asyncio_runner]) + mock_coroutine = mocker.NonCallableMagicMock(spec=Coroutine) + coro_stub = mocker.stub() + coro_stub.return_value = mock_coroutine + + # Act + ret_val = backend.bootstrap( + coro_stub, + mocker.sentinel.arg1, + mocker.sentinel.arg2, + mocker.sentinel.arg3, + runner_options=runner_options, + ) + + # Assert + if runner_options is None: + mock_asyncio_runner_cls.assert_called_once_with() + else: + mock_asyncio_runner_cls.assert_called_once_with(**runner_options) + + coro_stub.assert_called_once_with( + mocker.sentinel.arg1, + mocker.sentinel.arg2, + mocker.sentinel.arg3, + ) + + mock_asyncio_runner.run.assert_called_once_with(mock_coroutine) + mock_coroutine.close.assert_called_once_with() + assert ret_val is mocker.sentinel.Runner_ret_val + + @pytest.mark.asyncio class TestAsyncIOBackend: @pytest.fixture(params=[False, True], ids=lambda boolean: f"use_asyncio_transport=={boolean}") @@ -29,8 +79,8 @@ def use_asyncio_transport(request: Any) -> bool: @pytest.fixture @staticmethod - def backend(use_asyncio_transport: bool) -> AsyncioBackend: - return AsyncioBackend(transport=use_asyncio_transport) + def backend(use_asyncio_transport: bool) -> AsyncIOBackend: + return AsyncIOBackend(transport=use_asyncio_transport) @pytest.fixture(params=[("local_address", 12345), None], ids=lambda addr: f"local_address=={addr}") @staticmethod @@ -44,7 +94,7 @@ def remote_address(request: Any) -> tuple[str, int] | None: async def test____use_asyncio_transport____follows_option( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, ) -> None: assert backend.using_asyncio_transport() == use_asyncio_transport @@ -53,7 +103,7 @@ async def test____use_asyncio_transport____follows_option( async def test____coro_yield____use_asyncio_sleep( self, cancel_shielded: bool, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -70,7 +120,7 @@ async def test____coro_yield____use_asyncio_sleep( async def test____get_cancelled_exc_class____returns_asyncio_CancelledError( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: # Arrange @@ -80,7 +130,7 @@ async def test____get_cancelled_exc_class____returns_asyncio_CancelledError( async def test____current_time____use_event_loop_time( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -95,7 +145,7 @@ async def test____current_time____use_event_loop_time( async def test____sleep____use_asyncio_sleep( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -109,7 +159,7 @@ async def test____sleep____use_asyncio_sleep( async def test____ignore_cancellation____not_a_coroutine( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -128,7 +178,7 @@ async def test____create_tcp_connection____use_asyncio_open_connection( ssl: bool, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_asyncio_stream_reader_factory: Callable[[], MagicMock], mock_asyncio_stream_writer_factory: Callable[[], MagicMock], mock_ssl_context: MagicMock, @@ -191,7 +241,7 @@ async def test____create_tcp_connection____use_asyncio_open_connection____no_hap ssl: bool, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_asyncio_stream_reader_factory: Callable[[], MagicMock], mock_asyncio_stream_writer_factory: Callable[[], MagicMock], mock_ssl_context: MagicMock, @@ -246,7 +296,7 @@ async def test____create_tcp_connection____use_asyncio_open_connection____happy_ ssl: bool, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_asyncio_stream_reader_factory: Callable[[], MagicMock], mock_asyncio_stream_writer_factory: Callable[[], MagicMock], mock_ssl_context: MagicMock, @@ -301,7 +351,7 @@ async def test____create_tcp_connection____creates_raw_socket_adapter( event_loop: asyncio.AbstractEventLoop, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_tcp_socket: Callable[[], MagicMock], mocker: MockerFixture, ) -> None: @@ -341,7 +391,7 @@ async def test____create_tcp_connection____happy_eyeballs_delay_not_supported( self, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -377,7 +427,7 @@ async def test____create_ssl_over_tcp_connection____ssl_not_supported( self, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_ssl_context: MagicMock, mocker: MockerFixture, ) -> None: @@ -422,7 +472,7 @@ async def test____create_ssl_over_tcp_connection____invalid_ssl_context_value( self, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -457,7 +507,7 @@ async def test____create_ssl_over_tcp_connection____no_ssl_module( self, local_address: tuple[str, int] | None, remote_address: tuple[str, int], - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -489,7 +539,7 @@ async def test____create_ssl_over_tcp_connection____no_ssl_module( async def test____wrap_tcp_client_socket____use_asyncio_open_connection( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_tcp_socket: MagicMock, mock_asyncio_stream_reader_factory: Callable[[], MagicMock], @@ -530,7 +580,7 @@ async def test____wrap_tcp_client_socket____use_asyncio_open_connection( async def test____wrap_ssl_over_tcp_client_socket____use_asyncio_open_connection( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_tcp_socket: MagicMock, mock_asyncio_stream_reader_factory: Callable[[], MagicMock], @@ -586,7 +636,7 @@ async def test____wrap_ssl_over_tcp_client_socket____use_asyncio_open_connection async def test____wrap_ssl_over_tcp_client_socket____ssl_not_supported( self, mock_tcp_socket: MagicMock, - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_ssl_context: MagicMock, mocker: MockerFixture, ) -> None: @@ -628,7 +678,7 @@ async def test____wrap_ssl_over_tcp_client_socket____ssl_not_supported( async def test____wrap_ssl_over_tcp_client_socket____invalid_ssl_context_value( self, mock_tcp_socket: MagicMock, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -660,7 +710,7 @@ async def test____wrap_ssl_over_tcp_client_socket____invalid_ssl_context_value( async def test____wrap_ssl_over_tcp_client_socket____no_ssl_module( self, mock_tcp_socket: MagicMock, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -699,7 +749,7 @@ async def test____wrap_ssl_over_tcp_client_socket____no_ssl_module( async def test____create_tcp_listeners____open_listener_sockets( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_tcp_socket: MagicMock, use_ssl: bool, @@ -803,7 +853,7 @@ async def test____create_tcp_listeners____bind_to_any_interfaces( self, remote_host: str, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_tcp_socket: MagicMock, use_ssl: bool, @@ -915,7 +965,7 @@ async def test____create_tcp_listeners____bind_to_any_interfaces( async def test____create_tcp_listeners____bind_to_several_hosts( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_tcp_socket: MagicMock, use_ssl: bool, @@ -1031,7 +1081,7 @@ async def test____create_tcp_listeners____bind_to_several_hosts( async def test____create_tcp_listeners____error_getaddrinfo_returns_empty_list( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_ssl: bool, mock_ssl_context: MagicMock, mocker: MockerFixture, @@ -1095,7 +1145,7 @@ async def test____create_tcp_listeners____error_getaddrinfo_returns_empty_list( async def test____create_ssl_over_tcp_listeners____ssl_not_supported( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_ssl_context: MagicMock, mocker: MockerFixture, ) -> None: @@ -1143,7 +1193,7 @@ async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( event_loop: asyncio.AbstractEventLoop, local_address: tuple[str, int] | None, remote_address: tuple[str, int] | None, - backend: AsyncioBackend, + backend: AsyncIOBackend, mock_datagram_endpoint_factory: Callable[[], MagicMock], use_asyncio_transport: bool, mock_udp_socket: MagicMock, @@ -1201,7 +1251,7 @@ async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( async def test____wrap_udp_socket____use_loop_create_datagram_endpoint( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, use_asyncio_transport: bool, mock_udp_socket: MagicMock, mock_datagram_endpoint_factory: Callable[[], MagicMock], @@ -1241,7 +1291,7 @@ async def test____wrap_udp_socket____use_loop_create_datagram_endpoint( async def test____create_lock____use_asyncio_Lock_class( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -1256,7 +1306,7 @@ async def test____create_lock____use_asyncio_Lock_class( async def test____create_event____use_asyncio_Event_class( self, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -1273,7 +1323,7 @@ async def test____create_event____use_asyncio_Event_class( async def test____create_condition_var____use_asyncio_Condition_class( self, use_lock: type[asyncio.Lock] | None, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -1290,7 +1340,7 @@ async def test____create_condition_var____use_asyncio_Condition_class( async def test____run_in_thread____use_loop_run_in_executor( self, event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: # Arrange @@ -1328,8 +1378,7 @@ async def test____run_in_thread____use_loop_run_in_executor( async def test____create_threads_portal____returns_asyncio_portal( self, - event_loop: asyncio.AbstractEventLoop, - backend: AsyncioBackend, + backend: AsyncIOBackend, ) -> None: # Arrange from easynetwork_asyncio.threads import ThreadsPortal @@ -1339,4 +1388,3 @@ async def test____create_threads_portal____returns_asyncio_portal( # Assert assert isinstance(threads_portal, ThreadsPortal) - assert threads_portal.loop is event_loop diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py b/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py index 3d51666c..7aa875ab 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_datagram.py @@ -2,9 +2,9 @@ import asyncio from collections.abc import Callable -from errno import ENOTSOCK +from errno import ECONNABORTED from socket import AI_PASSIVE -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast from easynetwork_asyncio.datagram.endpoint import DatagramEndpoint, DatagramEndpointProtocol, create_datagram_endpoint from easynetwork_asyncio.datagram.socket import AsyncioTransportDatagramSocketAdapter, RawDatagramSocketAdapter @@ -90,10 +90,11 @@ async def test____create_datagram_endpoint____return_DatagramEndpoint_instance( class TestDatagramEndpoint: @pytest.fixture @staticmethod - def mock_asyncio_protocol(mocker: MockerFixture) -> MagicMock: + def mock_asyncio_protocol(mocker: MockerFixture, event_loop: asyncio.AbstractEventLoop) -> MagicMock: mock = mocker.NonCallableMagicMock(spec=DatagramEndpointProtocol) # Currently, _get_close_waiter() is a synchronous function returning a Future, but it will be awaited so this works mock._get_close_waiter = mocker.AsyncMock() + mock._get_loop.return_value = event_loop return mock @pytest.fixture @@ -216,10 +217,11 @@ async def test____recvfrom____await_recv_queue( # Assert mock_asyncio_recv_queue.get.assert_awaited_once_with() + mock_asyncio_recv_queue.get_nowait.assert_not_called() assert data == b"some data" assert address == ("an_address", 12345) - async def test____recvfrom____connection_lost____transport_already_closed( + async def test____recvfrom____connection_lost____transport_already_closed____data_in_queue( self, endpoint: DatagramEndpoint, mock_asyncio_transport: MagicMock, @@ -227,11 +229,41 @@ async def test____recvfrom____connection_lost____transport_already_closed( mock_asyncio_exception_queue: MagicMock, ) -> None: # Arrange - from errno import ECONNABORTED + mock_asyncio_exception_queue.get_nowait.side_effect = asyncio.QueueEmpty + mock_asyncio_transport.is_closing.return_value = True + mock_asyncio_recv_queue.get_nowait.return_value = (b"some data", ("an_address", 12345)) + + # Act + data, address = await endpoint.recvfrom() + + # Assert + mock_asyncio_exception_queue.get_nowait.assert_called_once() + mock_asyncio_recv_queue.get.assert_not_awaited() + mock_asyncio_recv_queue.get_nowait.assert_called_once() + assert data == b"some data" + assert address == ("an_address", 12345) + @pytest.mark.parametrize("condition", ["empty_queue", "None_pushed"]) + async def test____recvfrom____connection_lost____transport_already_closed____no_more_data( + self, + endpoint: DatagramEndpoint, + condition: Literal["empty_queue", "None_pushed"], + mock_asyncio_transport: MagicMock, + mock_asyncio_recv_queue: MagicMock, + mock_asyncio_exception_queue: MagicMock, + ) -> None: + # Arrange mock_asyncio_exception_queue.get_nowait.side_effect = asyncio.QueueEmpty mock_asyncio_transport.is_closing.return_value = True + match condition: + case "empty_queue": + mock_asyncio_recv_queue.get_nowait.side_effect = asyncio.QueueEmpty + case "None_pushed": + mock_asyncio_recv_queue.get_nowait.side_effect = [None] + case _: + pytest.fail("Invalid condition") + # Act with pytest.raises(OSError) as exc_info: await endpoint.recvfrom() @@ -240,6 +272,7 @@ async def test____recvfrom____connection_lost____transport_already_closed( assert exc_info.value.errno == ECONNABORTED mock_asyncio_exception_queue.get_nowait.assert_called_once() mock_asyncio_recv_queue.get.assert_not_awaited() + mock_asyncio_recv_queue.get_nowait.assert_called_once() async def test____recvfrom____connection_lost____transport_closed_by_protocol_while_waiting( self, @@ -374,22 +407,6 @@ async def test____get_extra_info____get_transport_extra_info( mock_asyncio_transport.get_extra_info.assert_called_once_with(mocker.sentinel.name, mocker.sentinel.default) assert value is mocker.sentinel.extra_info - async def test____get_loop____get_protocol_bound_loop( - self, - endpoint: DatagramEndpoint, - mock_asyncio_protocol: MagicMock, - mocker: MockerFixture, - ) -> None: - # Arrange - mock_asyncio_protocol._get_loop.return_value = mocker.sentinel.event_loop - - # Act - loop = endpoint.get_loop() - - # Assert - mock_asyncio_protocol._get_loop.assert_called_once_with() - assert loop is mocker.sentinel.event_loop - class TestDatagramEndpointProtocol: @pytest.fixture @@ -820,50 +837,6 @@ async def test____sendto____write_and_drain( # Assert mock_endpoint.sendto.assert_awaited_once_with(b"data to send", address) - async def test____getsockname____return_sockname_extra_info( - self, - socket: AsyncioTransportDatagramSocketAdapter, - endpoint_extra_info: dict[str, Any], - ) -> None: - # Arrange - - # Act - laddr = socket.get_local_address() - - # Assert - assert laddr == endpoint_extra_info["sockname"] - - async def test____getsockname____transport_closed( - self, - socket: AsyncioTransportDatagramSocketAdapter, - endpoint_extra_info: dict[str, Any], - ) -> None: - # Arrange - ### asyncio.DatagramTransport implementations will clear the extra dict on close() - endpoint_extra_info.clear() - - # Act & Assert - with pytest.raises(OSError) as exc_info: - socket.get_local_address() - - assert exc_info.value.errno == ENOTSOCK - - @pytest.mark.parametrize("peername", [("127.0.0.1", 9999), None]) - async def test____getpeername____return_peername_extra_info( - self, - peername: tuple[Any, ...] | None, - socket: AsyncioTransportDatagramSocketAdapter, - endpoint_extra_info: dict[str, Any], - ) -> None: - # Arrange - endpoint_extra_info["peername"] = peername - - # Act - raddr = socket.get_remote_address() - - # Assert - assert raddr == peername - async def test____socket____returns_transport_socket( self, socket: AsyncioTransportDatagramSocketAdapter, @@ -929,49 +902,6 @@ async def test____dunder_init____invalid_socket_type( with pytest.raises(ValueError, match=r"^A 'SOCK_DGRAM' socket is expected$"): _ = RawDatagramSocketAdapter(mock_tcp_socket, event_loop) - async def test____get_local_address____returns_socket_address( - self, - socket: RawDatagramSocketAdapter, - mock_udp_socket: MagicMock, - ) -> None: - # Arrange - - # Act - local_address = socket.get_local_address() - - # Assert - mock_udp_socket.getsockname.assert_called_once_with() - assert local_address == ("127.0.0.1", 11111) - - async def test____get_remote_address____returns_peer_address( - self, - socket: RawDatagramSocketAdapter, - mock_udp_socket: MagicMock, - ) -> None: - # Arrange - - # Act - remote_address = socket.get_remote_address() - - # Assert - mock_udp_socket.getpeername.assert_called_once_with() - assert remote_address == ("127.0.0.1", 12345) - - async def test____get_remote_address____no_peer_address( - self, - socket: RawDatagramSocketAdapter, - mock_udp_socket: MagicMock, - ) -> None: - # Arrange - self.configure_socket_mock_to_raise_ENOTCONN(mock_udp_socket) - - # Act - remote_address = socket.get_remote_address() - - # Assert - mock_udp_socket.getpeername.assert_called_once_with() - assert remote_address is None - async def test____is_closing____default( self, socket: RawDatagramSocketAdapter, diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index df0667f3..f9561b28 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -4,19 +4,13 @@ import asyncio import asyncio.trsock -import contextlib from collections.abc import Callable -from errno import ENOTCONN, ENOTSOCK from socket import SHUT_WR from typing import TYPE_CHECKING, Any from easynetwork.api_async.backend.abc import AcceptedSocket as AbstractAcceptedSocket from easynetwork_asyncio.stream.listener import AcceptedSocket, AcceptedSSLSocket, ListenerSocketAdapter -from easynetwork_asyncio.stream.socket import ( - AsyncioTransportHalfCloseableStreamSocketAdapter, - AsyncioTransportStreamSocketAdapter, - RawStreamSocketAdapter, -) +from easynetwork_asyncio.stream.socket import AsyncioTransportStreamSocketAdapter, RawStreamSocketAdapter import pytest @@ -66,7 +60,13 @@ def mock_asyncio_writer( @pytest.mark.asyncio -class BaseTestAsyncioTransportBasedStreamSocket(BaseTestTransportStreamSocket): +class TestAsyncioTransportStreamSocketAdapter(BaseTestTransportStreamSocket): + @pytest.fixture + @staticmethod + def socket(mock_asyncio_reader: MagicMock, mock_asyncio_writer: MagicMock) -> AsyncioTransportStreamSocketAdapter: + mock_asyncio_writer.can_write_eof.return_value = True + return AsyncioTransportStreamSocketAdapter(mock_asyncio_reader, mock_asyncio_writer) + async def test____aclose____close_transport_and_wait( self, socket: AsyncioTransportStreamSocketAdapter, @@ -224,62 +224,19 @@ async def test____sendall_fromiter____writelines_and_drain( mock_asyncio_writer.drain.assert_awaited_once_with() assert written_chunks == [b"data", b"to", b"send"] - async def test____getsockname____return_sockname_extra_info( - self, - socket: AsyncioTransportStreamSocketAdapter, - asyncio_writer_extra_info: dict[str, Any], - ) -> None: - # Arrange - - # Act - laddr = socket.get_local_address() - - # Assert - assert laddr == asyncio_writer_extra_info["sockname"] - - async def test____getsockname____transport_closed( - self, - socket: AsyncioTransportStreamSocketAdapter, - asyncio_writer_extra_info: dict[str, Any], - ) -> None: - # Arrange - ### asyncio.Transport implementations will clear the extra dict on close() - asyncio_writer_extra_info.clear() - - # Act - # Act & Assert - with pytest.raises(OSError) as exc_info: - socket.get_local_address() - - assert exc_info.value.errno == ENOTSOCK - - async def test____getpeername____return_peername_extra_info( + async def test____send_eof____write_eof( self, socket: AsyncioTransportStreamSocketAdapter, - asyncio_writer_extra_info: dict[str, Any], + mock_asyncio_writer: MagicMock, ) -> None: # Arrange + mock_asyncio_writer.write_eof.return_value = None # Act - raddr = socket.get_remote_address() + await socket.send_eof() # Assert - assert raddr == asyncio_writer_extra_info["peername"] - - async def test____getpeername____transport_not_connected( - self, - socket: AsyncioTransportStreamSocketAdapter, - asyncio_writer_extra_info: dict[str, Any], - ) -> None: - # Arrange - ### asyncio.Transport implementations explicitly set peername to None if the socket is not connected - asyncio_writer_extra_info["peername"] = None - - # Act & Assert - with pytest.raises(OSError) as exc_info: - socket.get_remote_address() - - assert exc_info.value.errno == ENOTCONN + mock_asyncio_writer.write_eof.assert_called_once_with() async def test____socket____returns_transport_socket( self, @@ -295,74 +252,6 @@ async def test____socket____returns_transport_socket( assert transport_socket is mock_tcp_socket -@pytest.mark.asyncio -class TestAsyncioTransportStreamSocketAdapter(BaseTestAsyncioTransportBasedStreamSocket): - @pytest.fixture - @staticmethod - def socket(mock_asyncio_reader: MagicMock, mock_asyncio_writer: MagicMock) -> AsyncioTransportStreamSocketAdapter: - mock_asyncio_writer.can_write_eof.return_value = False - return AsyncioTransportStreamSocketAdapter(mock_asyncio_reader, mock_asyncio_writer) - - @pytest.mark.parametrize("can_write_eof", [False, True], ids=lambda p: f"can_write_eof=={p}") - async def test____dunder_new____instanciate_subclass( - self, - can_write_eof: bool, - mock_asyncio_reader: MagicMock, - mock_asyncio_writer: MagicMock, - ) -> None: - # Arrange - mock_asyncio_writer.can_write_eof.return_value = can_write_eof - - # Act - socket = AsyncioTransportStreamSocketAdapter(mock_asyncio_reader, mock_asyncio_writer) - - # Assert - if can_write_eof: - assert type(socket) is AsyncioTransportHalfCloseableStreamSocketAdapter - else: - assert type(socket) is AsyncioTransportStreamSocketAdapter - - -@pytest.mark.asyncio -class TestAsyncioTransportHalfCloseableStreamSocketAdapter(BaseTestAsyncioTransportBasedStreamSocket): - @pytest.fixture - @staticmethod - def socket( - mock_asyncio_reader: MagicMock, - mock_asyncio_writer: MagicMock, - ) -> AsyncioTransportHalfCloseableStreamSocketAdapter: - mock_asyncio_writer.can_write_eof.return_value = True - return AsyncioTransportHalfCloseableStreamSocketAdapter(mock_asyncio_reader, mock_asyncio_writer) - - @pytest.mark.parametrize("can_write_eof", [False, True], ids=lambda p: f"can_write_eof=={p}") - async def test____dunder_new____check_write_eof_possibility( - self, - can_write_eof: bool, - mock_asyncio_reader: MagicMock, - mock_asyncio_writer: MagicMock, - ) -> None: - # Arrange - mock_asyncio_writer.can_write_eof.return_value = can_write_eof - - # Act & Assert - with pytest.raises(ValueError) if not can_write_eof else contextlib.nullcontext(): - _ = AsyncioTransportHalfCloseableStreamSocketAdapter(mock_asyncio_reader, mock_asyncio_writer) - - async def test____send_eof____write_eof( - self, - socket: AsyncioTransportHalfCloseableStreamSocketAdapter, - mock_asyncio_writer: MagicMock, - ) -> None: - # Arrange - mock_asyncio_writer.write_eof.return_value = None - - # Act - await socket.send_eof() - - # Assert - mock_asyncio_writer.write_eof.assert_called_once_with() - - @pytest.mark.asyncio class TestListenerSocketAdapter(BaseTestTransportStreamSocket, BaseTestSocket): @pytest.fixture(autouse=True) @@ -417,20 +306,6 @@ async def test____dunder_init____default( # Assert assert listener.socket() is mock_tcp_listener_socket - async def test____get_local_address____returns_socket_address( - self, - listener: ListenerSocketAdapter, - mock_tcp_listener_socket: MagicMock, - ) -> None: - # Arrange - - # Act - local_address = listener.get_local_address() - - # Assert - mock_tcp_listener_socket.getsockname.assert_called_once_with() - assert local_address == ("127.0.0.1", 11111) - async def test____is_closing____default( self, listener: ListenerSocketAdapter, @@ -512,12 +387,6 @@ def mock_raw_stream_socket_adapter_cls(mock_stream_socket_adapter: MagicMock, mo return_value=mock_stream_socket_adapter, ) - @pytest.fixture - @staticmethod - def mock_stream_socket_adapter(mock_stream_socket_adapter: MagicMock) -> MagicMock: - mock_stream_socket_adapter.get_remote_address.return_value = ("127.0.0.1", 12345) - return mock_stream_socket_adapter - @pytest.fixture @staticmethod def mock_asyncio_reader_cls(mock_asyncio_reader: MagicMock, mocker: MockerFixture) -> MagicMock: @@ -642,12 +511,6 @@ def mock_transport_stream_socket_adapter_cls(mock_stream_socket_adapter: MagicMo return_value=mock_stream_socket_adapter, ) - @pytest.fixture - @staticmethod - def mock_stream_socket_adapter(mock_stream_socket_adapter: MagicMock) -> MagicMock: - mock_stream_socket_adapter.get_remote_address.return_value = ("127.0.0.1", 12345) - return mock_stream_socket_adapter - @pytest.fixture @staticmethod def mock_asyncio_reader_cls(mock_asyncio_reader: MagicMock, mocker: MockerFixture) -> MagicMock: @@ -804,34 +667,6 @@ async def test____dunder_init____invalid_socket_type( with pytest.raises(ValueError, match=r"^A 'SOCK_STREAM' socket is expected$"): _ = RawStreamSocketAdapter(mock_udp_socket, event_loop) - async def test____get_local_address____returns_socket_address( - self, - socket: RawStreamSocketAdapter, - mock_tcp_socket: MagicMock, - ) -> None: - # Arrange - - # Act - local_address = socket.get_local_address() - - # Assert - mock_tcp_socket.getsockname.assert_called_once_with() - assert local_address == ("127.0.0.1", 11111) - - async def test____get_remote_address____returns_peer_address( - self, - socket: RawStreamSocketAdapter, - mock_tcp_socket: MagicMock, - ) -> None: - # Arrange - - # Act - remote_address = socket.get_remote_address() - - # Assert - mock_tcp_socket.getpeername.assert_called_once_with() - assert remote_address == ("127.0.0.1", 12345) - async def test____is_closing____default( self, socket: RawStreamSocketAdapter, diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py b/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py index 94d64859..b81801bd 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_tasks.py @@ -1,16 +1,14 @@ from __future__ import annotations import asyncio -import math -from collections.abc import Callable from typing import TYPE_CHECKING, Any -from easynetwork_asyncio.tasks import Task, TimeoutHandle, move_on_after, move_on_at, timeout, timeout_at +from easynetwork_asyncio.tasks import Task, TaskUtils import pytest if TYPE_CHECKING: - from unittest.mock import AsyncMock, MagicMock + from unittest.mock import AsyncMock from pytest_mock import MockerFixture @@ -123,7 +121,7 @@ async def test____wait____task_already_done( await task.wait() # Assert - mock_asyncio_wait.assert_not_called() + mock_asyncio_wait.assert_awaited_once_with({mock_asyncio_task}) @pytest.mark.asyncio async def test____join____await_task( @@ -147,194 +145,34 @@ async def test____join____await_task( assert result is mocker.sentinel.task_result -class TestTimeout: - @pytest.fixture - @staticmethod - def mock_asyncio_timeout_handle(mocker: MockerFixture) -> MagicMock: - mock = mocker.NonCallableMagicMock(spec=asyncio.Timeout) - mock.__aenter__.return_value = mock - mock.when.return_value = None - mock.reschedule.return_value = None - mock.expired.return_value = False - return mock - - @pytest.fixture - @staticmethod - def timeout_handle(mock_asyncio_timeout_handle: MagicMock) -> TimeoutHandle: - return TimeoutHandle(mock_asyncio_timeout_handle) - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout", [timeout, move_on_after]) - async def test____timeout____schedule_timeout( - self, - timeout: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: +class TestTaskUtils: + def test____current_asyncio_task____return_current_task(self) -> None: # Arrange - mock_timeout = mocker.patch("asyncio.timeout", return_value=mock_asyncio_timeout_handle) + async def main() -> None: + assert TaskUtils.current_asyncio_task() is asyncio.current_task() - # Act - async with timeout(123456789): - pass - - # Assert - mock_timeout.assert_called_once_with(123456789.0) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout", [timeout, move_on_after]) - async def test____timeout____handle_infinite( - self, - timeout: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout = mocker.patch("asyncio.timeout", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout(math.inf): - pass - - # Assert - mock_timeout.assert_called_once_with(None) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout_at", [timeout_at, move_on_at]) - async def test____timeout_at____schedule_timeout( - self, - timeout_at: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout_at = mocker.patch("asyncio.timeout_at", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout_at(123456789): - pass - - # Assert - mock_timeout_at.assert_called_once_with(123456789.0) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - @pytest.mark.asyncio - @pytest.mark.parametrize("timeout_at", [timeout_at, move_on_at]) - async def test____timeout_at____handle_infinite( - self, - timeout_at: Callable[[float], TimeoutHandle], - mocker: MockerFixture, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_timeout_at = mocker.patch("asyncio.timeout_at", return_value=mock_asyncio_timeout_handle) - - # Act - async with timeout_at(math.inf): - pass - - # Assert - mock_timeout_at.assert_called_once_with(None) - mock_asyncio_timeout_handle.__aenter__.assert_called_once() - - def test____timeout_handle____when____return_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = 123456.789 - - # Act - deadline: float = timeout_handle.when() - - # Assert - assert deadline == 123456.789 - - def test____timeout_handle____when____infinite_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = None - - # Act - deadline: float = timeout_handle.when() - - # Assert - assert deadline == math.inf - - def test____timeout_handle____reschedule____set_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - - # Act - timeout_handle.reschedule(123456.789) - - # Assert - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(123456.789) - - def test____timeout_handle____reschedule____infinite_deadline( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - - # Act - timeout_handle.reschedule(math.inf) - - # Assert - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(None) - - @pytest.mark.parametrize("expected_retval", [False, True]) - def test____expired____expected_return_value( - self, - expected_retval: bool, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.expired.return_value = expected_retval - - # Act - expired: bool = timeout_handle.expired() - - # Assert - assert expired is expected_retval + # Act & Assert + asyncio.run(main()) - def test____deadline_property____set( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: + def test____current_asyncio_task____called_from_callback(self) -> None: # Arrange - mock_asyncio_timeout_handle.when.return_value = 4 + async def main() -> None: + loop = asyncio.get_running_loop() + future: asyncio.Future[None] = loop.create_future() - # Act - timeout_handle.deadline += 12 - - # Assert - mock_asyncio_timeout_handle.when.assert_called_once_with() - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(16) + def callback() -> None: + try: + _ = TaskUtils.current_asyncio_task() + future.set_result(None) + except BaseException as exc: + future.set_exception(exc) + finally: + future.cancel() - def test____deadline_property____delete( - self, - timeout_handle: TimeoutHandle, - mock_asyncio_timeout_handle: MagicMock, - ) -> None: - # Arrange - mock_asyncio_timeout_handle.when.return_value = 4 + loop.call_soon(callback) - # Act - del timeout_handle.deadline + with pytest.raises(RuntimeError, match=r"This function should be called within a task\."): + await future - # Assert - mock_asyncio_timeout_handle.when.assert_not_called() - mock_asyncio_timeout_handle.reschedule.assert_called_once_with(None) + # Act & Assert + asyncio.run(main()) diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_threads.py b/tests/unit_test/test_async/test_asyncio_backend/test_threads.py deleted file mode 100644 index 9e33e656..00000000 --- a/tests/unit_test/test_async/test_asyncio_backend/test_threads.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, Any - -from easynetwork_asyncio.threads import ThreadsPortal - -import pytest - -if TYPE_CHECKING: - from unittest.mock import AsyncMock, MagicMock - - from pytest_mock import MockerFixture - - -@pytest.mark.asyncio -class TestAsyncioThreadsPortal: - @pytest.fixture - @staticmethod - def threads_portal(event_loop: asyncio.AbstractEventLoop) -> ThreadsPortal: - return ThreadsPortal(loop=event_loop) - - async def test____run_coroutine____use_asyncio_run_coroutine_threadsafe( - self, - event_loop: asyncio.AbstractEventLoop, - threads_portal: ThreadsPortal, - mocker: MockerFixture, - ) -> None: - # Arrange - coro_func_stub: AsyncMock = mocker.async_stub() - coro_func_stub.return_value = mocker.sentinel.return_value - mock_run_coroutine_threadsafe: MagicMock = mocker.patch( - "asyncio.run_coroutine_threadsafe", - side_effect=asyncio.run_coroutine_threadsafe, - ) - - # Act - ret_val: Any = None - - def test_thread() -> None: - nonlocal ret_val - - ret_val = threads_portal.run_coroutine( - coro_func_stub, - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - - await asyncio.to_thread(test_thread) - - # Assert - coro_func_stub.assert_awaited_once_with( - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - mock_run_coroutine_threadsafe.assert_called_once_with(mocker.ANY, event_loop) - assert ret_val is mocker.sentinel.return_value - - async def test____run_coroutine____error_if_called_from_event_loop_thread( - self, - threads_portal: ThreadsPortal, - mocker: MockerFixture, - ) -> None: - # Arrange - func_stub: AsyncMock = mocker.async_stub() - mock_run_coroutine_threadsafe: MagicMock = mocker.patch( - "asyncio.run_coroutine_threadsafe", - side_effect=asyncio.run_coroutine_threadsafe, - ) - - # Act - with pytest.raises(RuntimeError): - threads_portal.run_coroutine( - func_stub, - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - - # Assert - func_stub.assert_not_called() - mock_run_coroutine_threadsafe.assert_not_called() - - async def test____run_sync____use_asyncio_event_loop_call_soon_threadsafe( - self, - event_loop: asyncio.AbstractEventLoop, - threads_portal: ThreadsPortal, - mocker: MockerFixture, - ) -> None: - # Arrange - func_stub: MagicMock = mocker.stub() - func_stub.return_value = mocker.sentinel.return_value - mock_loop_call_soon_threadsafe: MagicMock = mocker.patch.object( - event_loop, - "call_soon_threadsafe", - side_effect=event_loop.call_soon_threadsafe, - ) - - # Act - ret_val: Any = None - - def test_thread() -> None: - nonlocal ret_val - - ret_val = threads_portal.run_sync( - func_stub, - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - mock_loop_call_soon_threadsafe.assert_called_once() - - await asyncio.to_thread(test_thread) - - # Assert - func_stub.assert_called_once_with( - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - assert ret_val is mocker.sentinel.return_value - - async def test____run_sync____error_if_called_from_event_loop_thread( - self, - event_loop: asyncio.AbstractEventLoop, - threads_portal: ThreadsPortal, - mocker: MockerFixture, - ) -> None: - # Arrange - func_stub: MagicMock = mocker.stub() - mock_loop_call_soon_threadsafe: MagicMock = mocker.patch.object( - event_loop, - "call_soon_threadsafe", - side_effect=event_loop.call_soon_threadsafe, - ) - - # Act - with pytest.raises(RuntimeError): - threads_portal.run_sync( - func_stub, - mocker.sentinel.arg1, - mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, - ) - - # Assert - func_stub.assert_not_called() - mock_loop_call_soon_threadsafe.assert_not_called() diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index 2c8e003e..13af6a4d 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -15,6 +15,7 @@ check_socket_no_ssl, ensure_datagram_socket_bound, error_from_errno, + exception_with_notes, is_ssl_eof_error, is_ssl_socket, iter_bytes, @@ -23,7 +24,6 @@ remove_traceback_frames_in_place, replace_kwargs, set_reuseport, - transform_future_exception, validate_timeout_delay, wait_socket_available, ) @@ -521,30 +521,30 @@ def test____set_reuseport____not_supported____defined_but_not_implemented( mock_tcp_socket.setsockopt.assert_called_once_with(SOL_SOCKET, SO_REUSEPORT, True) -def test____transform_future_exception____keep_common_exception_as_is() -> None: +def test____exception_with_notes____one_note() -> None: # Arrange - exception = BaseException() + exception = Exception() + note = "A note." # Act - new_exception = transform_future_exception(exception) + returned_exception = exception_with_notes(exception, note) # Assert - assert new_exception is exception + assert returned_exception is exception + assert exception.__notes__ == [note] -@pytest.mark.parametrize("exception", [SystemExit(0), KeyboardInterrupt()], ids=lambda f: type(f).__name__) -def test____transform_future_exception____make_cancelled_error_from_exception(exception: BaseException) -> None: +def test____exception_with_notes____several_notes() -> None: # Arrange - from concurrent.futures import CancelledError + exception = Exception() + notes = ["A note.", "Another note.", "Third one."] # Act - new_exception = transform_future_exception(exception) + returned_exception = exception_with_notes(exception, notes) # Assert - assert type(new_exception) is CancelledError - assert new_exception.__cause__ is exception - assert new_exception.__context__ is exception - assert new_exception.__suppress_context__ + assert returned_exception is exception + assert exception.__notes__ == list(notes) @pytest.mark.parametrize("n", [-1, 0, 2, 2000])