diff --git a/docs/source/conf.py b/docs/source/conf.py index f6c08526..d436eac9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -119,7 +119,6 @@ ("Common Parameters", "params_style"), ("Socket Parameters", "params_style"), ("Connection Parameters", "params_style"), - ("Backend Parameters", "params_style"), ] diff --git a/src/easynetwork/api_async/client/abc.py b/src/easynetwork/api_async/client/abc.py index e5c36cfe..2213bab9 100644 --- a/src/easynetwork/api_async/client/abc.py +++ b/src/easynetwork/api_async/client/abc.py @@ -25,13 +25,12 @@ from ..._typevars import _ReceivedPacketT, _SentPacketT from ...lowlevel import _utils +from ...lowlevel.api_async.backend.factory import current_async_backend from ...lowlevel.socket import SocketAddress if TYPE_CHECKING: from types import TracebackType - from ...lowlevel.api_async.backend.abc import AsyncBackend - class AbstractAsyncNetworkClient(Generic[_SentPacketT, _ReceivedPacketT], metaclass=ABCMeta): """ @@ -220,7 +219,7 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter if timeout is None: timeout = math.inf - timeout_after = self.get_backend().timeout + timeout_after = current_async_backend().timeout while True: try: @@ -230,7 +229,3 @@ async def iter_received_packets(self, *, timeout: float | None = 0) -> AsyncIter return yield packet timeout = elapsed.recompute_timeout(timeout) - - @abstractmethod - def get_backend(self) -> AsyncBackend: - raise NotImplementedError diff --git a/src/easynetwork/api_async/client/tcp.py b/src/easynetwork/api_async/client/tcp.py index c9c61373..601567b4 100644 --- a/src/easynetwork/api_async/client/tcp.py +++ b/src/easynetwork/api_async/client/tcp.py @@ -22,7 +22,7 @@ import dataclasses import errno as _errno import socket as _socket -from collections.abc import Awaitable, Callable, Iterator, Mapping +from collections.abc import Awaitable, Callable, Iterator from typing import TYPE_CHECKING, Any, final, overload try: @@ -36,8 +36,8 @@ from ..._typevars import _ReceivedPacketT, _SentPacketT from ...exceptions import ClientClosedError from ...lowlevel import _utils, constants -from ...lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, ILock -from ...lowlevel.api_async.backend.factory import AsyncBackendFactory +from ...lowlevel.api_async.backend.abc import CancelScope, ILock +from ...lowlevel.api_async.backend.factory import current_async_backend from ...lowlevel.api_async.endpoints.stream import AsyncStreamEndpoint from ...lowlevel.api_async.transports.abc import AsyncStreamTransport from ...lowlevel.socket import ( @@ -79,7 +79,6 @@ class AsyncTCPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPa __slots__ = ( "__endpoint", "__protocol", - "__backend", "__socket_connector", "__socket_proxy", "__receive_lock", @@ -102,8 +101,6 @@ def __init__( ssl_shutdown_timeout: float | None = ..., ssl_shared_lock: bool | None = ..., max_recv_size: int | None = ..., - backend: str | AsyncBackend | None = ..., - backend_kwargs: Mapping[str, Any] | None = ..., ) -> None: ... @@ -120,8 +117,6 @@ def __init__( ssl_shutdown_timeout: float | None = ..., ssl_shared_lock: bool | None = ..., max_recv_size: int | None = ..., - backend: str | AsyncBackend | None = ..., - backend_kwargs: Mapping[str, Any] | None = ..., ) -> None: ... @@ -137,8 +132,6 @@ def __init__( ssl_shutdown_timeout: float | None = None, ssl_shared_lock: bool | None = None, max_recv_size: int | None = None, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> None: """ @@ -172,11 +165,6 @@ def __init__( the same lock instance. max_recv_size: Read buffer size. If not given, a default reasonable value is used. - Backend Parameters: - backend: the backend to use. Automatically determined otherwise. - backend_kwargs: Keyword arguments for backend instanciation. - Ignored if `backend` is already an :class:`.AsyncBackend` instance. - See Also: :ref:`SSL/TLS security considerations ` """ @@ -185,14 +173,13 @@ def __init__( if not isinstance(protocol, StreamProtocol): raise TypeError(f"Expected a StreamProtocol object, got {protocol!r}") - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) + backend = current_async_backend() if max_recv_size is None: max_recv_size = constants.DEFAULT_STREAM_BUFSIZE if not isinstance(max_recv_size, int) or max_recv_size <= 0: raise ValueError("'max_recv_size' must be a strictly positive integer") self.__endpoint: AsyncStreamEndpoint[_SentPacketT, _ReceivedPacketT] | None = None - self.__backend: AsyncBackend = backend self.__socket_proxy: SocketProxy | None = None self.__protocol: StreamProtocol[_SentPacketT, _ReceivedPacketT] = protocol @@ -261,9 +248,9 @@ def _value_or_default(value: float | None, default: float) -> float: raise TypeError("Invalid arguments") self.__socket_connector: _SocketConnector | None = _SocketConnector( - lock=self.__backend.create_lock(), + lock=backend.create_lock(), factory=_utils.make_callback(self.__create_socket, socket_factory), - scope=self.__backend.open_cancel_scope(), + scope=backend.open_cancel_scope(), ) assert ssl_shared_lock is not None # nosec assert_used @@ -477,10 +464,6 @@ def get_remote_address(self) -> SocketAddress: address_family = endpoint.extra(INETSocketAttribute.family) return new_socket_address(remote_address, address_family) - @_utils.inherit_doc(AbstractAsyncNetworkClient) - def get_backend(self) -> AsyncBackend: - return self.__backend - async def __ensure_connected(self) -> AsyncStreamEndpoint[_SentPacketT, _ReceivedPacketT]: if self.__endpoint is None: endpoint_and_proxy = None diff --git a/src/easynetwork/api_async/client/udp.py b/src/easynetwork/api_async/client/udp.py index 06876b9d..28eceff4 100644 --- a/src/easynetwork/api_async/client/udp.py +++ b/src/easynetwork/api_async/client/udp.py @@ -22,14 +22,14 @@ import dataclasses as _dataclasses import errno as _errno import socket as _socket -from collections.abc import Awaitable, Callable, Iterator, Mapping +from collections.abc import Awaitable, Callable, Iterator from typing import Any, final, overload from ..._typevars import _ReceivedPacketT, _SentPacketT from ...exceptions import ClientClosedError from ...lowlevel import _utils, constants -from ...lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, ILock -from ...lowlevel.api_async.backend.factory import AsyncBackendFactory +from ...lowlevel.api_async.backend.abc import CancelScope, ILock +from ...lowlevel.api_async.backend.factory import current_async_backend from ...lowlevel.api_async.endpoints.datagram import AsyncDatagramEndpoint from ...lowlevel.api_async.transports.abc import AsyncDatagramTransport from ...lowlevel.socket import INETSocketAttribute, SocketAddress, SocketProxy, new_socket_address @@ -61,7 +61,6 @@ class AsyncUDPNetworkClient(AbstractAsyncNetworkClient[_SentPacketT, _ReceivedPa __slots__ = ( "__endpoint", "__socket_proxy", - "__backend", "__socket_connector", "__receive_lock", "__send_lock", @@ -77,8 +76,6 @@ def __init__( *, local_address: tuple[str, int] | None = ..., family: int = ..., - backend: str | AsyncBackend | None = ..., - backend_kwargs: Mapping[str, Any] | None = ..., ) -> None: ... @@ -88,9 +85,6 @@ def __init__( socket: _socket.socket, /, protocol: DatagramProtocol[_SentPacketT, _ReceivedPacketT], - *, - backend: str | AsyncBackend | None = ..., - backend_kwargs: Mapping[str, Any] | None = ..., ) -> None: ... @@ -99,9 +93,6 @@ def __init__( __arg: tuple[str, int] | _socket.socket, /, protocol: DatagramProtocol[_SentPacketT, _ReceivedPacketT], - *, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> None: """ @@ -119,22 +110,16 @@ def __init__( Socket Parameters: socket: An already connected UDP :class:`socket.socket`. If `socket` is given, none of and `local_address` and `reuse_port` should be specified. - - Backend Parameters: - backend: the backend to use. Automatically determined otherwise. - backend_kwargs: Keyword arguments for backend instanciation. - Ignored if `backend` is already an :class:`.AsyncBackend` instance. """ super().__init__() - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) + backend = current_async_backend() if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__protocol: DatagramProtocol[_SentPacketT, _ReceivedPacketT] = protocol self.__endpoint: AsyncDatagramEndpoint[_SentPacketT, _ReceivedPacketT] | None = None - self.__backend: AsyncBackend = backend self.__socket_proxy: SocketProxy | None = None socket_factory: Callable[[], Awaitable[AsyncDatagramTransport]] @@ -155,9 +140,9 @@ def __init__( raise TypeError("Invalid arguments") self.__socket_connector: _SocketConnector | None = _SocketConnector( - lock=self.__backend.create_lock(), + lock=backend.create_lock(), factory=_utils.make_callback(self.__create_socket, socket_factory), - scope=self.__backend.open_cancel_scope(), + scope=backend.open_cancel_scope(), ) self.__receive_lock: ILock = backend.create_lock() self.__send_lock: ILock = backend.create_lock() @@ -323,10 +308,6 @@ def get_remote_address(self) -> SocketAddress: address_family = endpoint.extra(INETSocketAttribute.family) return new_socket_address(remote_address, address_family) - @_utils.inherit_doc(AbstractAsyncNetworkClient) - def get_backend(self) -> AsyncBackend: - return self.__backend - async def __ensure_connected(self) -> AsyncDatagramEndpoint[_SentPacketT, _ReceivedPacketT]: if self.__endpoint is None: endpoint_and_proxy = None diff --git a/src/easynetwork/api_async/server/abc.py b/src/easynetwork/api_async/server/abc.py index b9b12fe1..2e998b9a 100644 --- a/src/easynetwork/api_async/server/abc.py +++ b/src/easynetwork/api_async/server/abc.py @@ -30,8 +30,6 @@ if TYPE_CHECKING: from types import TracebackType - from ...lowlevel.api_async.backend.abc import AsyncBackend - class SupportsEventSet(Protocol): """ @@ -116,7 +114,3 @@ def get_addresses(self) -> Sequence[SocketAddress]: A sequence of network socket address. If the server is not serving (:meth:`is_serving` returns :data:`False`), an empty sequence is returned. """ - - @abstractmethod - def get_backend(self) -> AsyncBackend: - raise NotImplementedError diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index fdcc4615..8b97a42c 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -30,8 +30,8 @@ from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError from ...lowlevel import _asyncgen, _utils, constants from ...lowlevel._final import runtime_final_class -from ...lowlevel.api_async.backend.factory import AsyncBackendFactory -from ...lowlevel.api_async.servers import stream as lowlevel_stream_server +from ...lowlevel.api_async.backend.factory import current_async_backend +from ...lowlevel.api_async.servers import stream as _stream_server from ...lowlevel.socket import ( INETSocketAttribute, ISocket, @@ -49,7 +49,7 @@ if TYPE_CHECKING: import ssl as _typing_ssl - from ...lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, IEvent, Task, TaskGroup + from ...lowlevel.api_async.backend.abc import CancelScope, IEvent, Task, TaskGroup from ...lowlevel.api_async.transports.abc import AsyncListener, AsyncStreamTransport @@ -59,7 +59,6 @@ class AsyncTCPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp """ __slots__ = ( - "__backend", "__servers", "__listeners_factory", "__listeners_factory_scope", @@ -90,8 +89,6 @@ def __init__( max_recv_size: int | None = None, log_client_connection: bool | None = None, logger: logging.Logger | None = None, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, ) -> None: """ Parameters: @@ -124,11 +121,6 @@ def __init__( (This log will always be available in :data:`logging.DEBUG` level.) logger: If given, the logger instance to use. - Backend Parameters: - backend: the backend to use. Automatically determined otherwise. - backend_kwargs: Keyword arguments for backend instanciation. - Ignored if `backend` is already an :class:`.AsyncBackend` instance. - See Also: :ref:`SSL/TLS security considerations ` """ @@ -139,7 +131,7 @@ def __init__( if not isinstance(request_handler, AsyncStreamRequestHandler): raise TypeError(f"Expected an AsyncStreamRequestHandler object, got {request_handler!r}") - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) + backend = current_async_backend() if backlog is None: backlog = 100 @@ -183,11 +175,10 @@ def _value_or_default(value: float | None, default: float) -> float: ) self.__listeners_factory_scope: CancelScope | None = None - self.__backend: AsyncBackend = backend - self.__servers: tuple[lowlevel_stream_server.AsyncStreamServer[_RequestT, _ResponseT], ...] | None = None + self.__servers: tuple[_stream_server.AsyncStreamServer[_RequestT, _ResponseT], ...] | None = None self.__protocol: StreamProtocol[_ResponseT, _RequestT] = protocol self.__request_handler: AsyncStreamRequestHandler[_RequestT, _ResponseT] = request_handler - self.__is_shutdown: IEvent = self.__backend.create_event() + self.__is_shutdown: IEvent = backend.create_event() self.__is_shutdown.set() self.__shutdown_asked: bool = False self.__max_recv_size: int = max_recv_size @@ -239,7 +230,7 @@ async def __close_servers(self) -> None: exit_stack.push_async_callback(server_task.wait) del server_task - await self.__backend.cancel_shielded_coro_yield() + await current_async_backend().cancel_shielded_coro_yield() @_utils.inherit_doc(AbstractAsyncNetworkServer) async def shutdown(self) -> None: @@ -260,7 +251,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Wake up server if not self.__is_shutdown.is_set(): raise ServerAlreadyRunning("Server is already running") - self.__is_shutdown = is_shutdown = self.__backend.create_event() + self.__is_shutdown = is_shutdown = current_async_backend().create_event() server_exit_stack.callback(is_shutdown.set) ################ @@ -271,8 +262,8 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> raise ServerClosedError("Closed server") listeners: list[AsyncListener[AsyncStreamTransport]] = [] try: - with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: - await self.__backend.coro_yield() + with current_async_backend().open_cancel_scope() as self.__listeners_factory_scope: + await current_async_backend().coro_yield() listeners.extend(await self.__listeners_factory()) if self.__listeners_factory_scope.cancelled_caught(): raise ServerClosedError("Closed server") @@ -281,12 +272,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> if not listeners: raise OSError("empty listeners list") self.__servers = tuple( - lowlevel_stream_server.AsyncStreamServer( - listener, - self.__protocol, - max_recv_size=self.__max_recv_size, - backend=self.__backend, - ) + _stream_server.AsyncStreamServer(listener, self.__protocol, max_recv_size=self.__max_recv_size) for listener in listeners ) del listeners @@ -307,7 +293,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Setup task group self.__active_tasks = 0 server_exit_stack.callback(self.__servers_tasks.clear) - task_group = await server_exit_stack.enter_async_context(self.__backend.create_task_group()) + task_group = await server_exit_stack.enter_async_context(current_async_backend().create_task_group()) server_exit_stack.callback(self.__logger.info, "Server loop break, waiting for remaining tasks...") ################## @@ -322,7 +308,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> ############## # Main loop - self.__mainloop_task = task_group.start_soon(self.__backend.sleep_forever) + self.__mainloop_task = task_group.start_soon(current_async_backend().sleep_forever) if self.__shutdown_asked: self.__mainloop_task.cancel() try: @@ -334,7 +320,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> async def __serve( self, - server: lowlevel_stream_server.AsyncStreamServer[_RequestT, _ResponseT], + server: _stream_server.AsyncStreamServer[_RequestT, _ResponseT], task_group: TaskGroup, ) -> NoReturn: self.__attach_server() @@ -346,7 +332,7 @@ async def __serve( async def __client_coroutine( self, - lowlevel_client: lowlevel_stream_server.AsyncStreamClient[_ResponseT], + lowlevel_client: _stream_server.AsyncStreamClient[_ResponseT], ) -> AsyncGenerator[None, _RequestT]: async with contextlib.AsyncExitStack() as client_exit_stack: self.__attach_server() @@ -368,8 +354,7 @@ async def __client_coroutine( client_exit_stack.callback(self.__set_socket_linger_if_not_closed, lowlevel_client.extra(INETSocketAttribute.socket)) logger: logging.Logger = self.__logger - backend: AsyncBackend = self.__backend - client = _ConnectedClientAPI(client_address, backend, lowlevel_client, logger) + client = _ConnectedClientAPI(client_address, lowlevel_client, logger) del lowlevel_client @@ -401,6 +386,7 @@ async def disconnect_client() -> None: del client_exit_stack + backend = current_async_backend() try: action: _asyncgen.AsyncGenAction[None, _RequestT] while not client.is_closing(): @@ -478,10 +464,6 @@ def get_addresses(self) -> Sequence[SocketAddress]: if not server.is_closing() ) - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def get_backend(self) -> AsyncBackend: - return self.__backend - @property def sockets(self) -> Sequence[SocketProxy]: """The listeners sockets. Read-only attribute.""" @@ -510,13 +492,12 @@ class _ConnectedClientAPI(AsyncStreamClient[_ResponseT]): def __init__( self, address: SocketAddress, - backend: AsyncBackend, - client: lowlevel_stream_server.AsyncStreamClient[_ResponseT], + client: _stream_server.AsyncStreamClient[_ResponseT], logger: logging.Logger, ) -> None: - self.__client: lowlevel_stream_server.AsyncStreamClient[_ResponseT] = client + self.__client: _stream_server.AsyncStreamClient[_ResponseT] = client self.__closed: bool = False - self.__send_lock = backend.create_lock() + self.__send_lock = current_async_backend().create_lock() self.__logger: logging.Logger = logger self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket)) self.__address: SocketAddress = address diff --git a/src/easynetwork/api_async/server/udp.py b/src/easynetwork/api_async/server/udp.py index 63205dd9..25b2d4db 100644 --- a/src/easynetwork/api_async/server/udp.py +++ b/src/easynetwork/api_async/server/udp.py @@ -29,8 +29,8 @@ from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError from ...lowlevel import _asyncgen, _utils from ...lowlevel._final import runtime_final_class -from ...lowlevel.api_async.backend.factory import AsyncBackendFactory -from ...lowlevel.api_async.servers import datagram as lowlevel_datagram_server +from ...lowlevel.api_async.backend.factory import current_async_backend +from ...lowlevel.api_async.servers import datagram as _datagram_server from ...lowlevel.api_async.transports.abc import AsyncDatagramListener from ...lowlevel.socket import INETSocketAttribute, SocketAddress, SocketProxy, new_socket_address from ...protocol import DatagramProtocol @@ -38,7 +38,7 @@ from .handler import AsyncDatagramClient, AsyncDatagramRequestHandler, INETClientAttribute if TYPE_CHECKING: - from ...lowlevel.api_async.backend.abc import AsyncBackend, CancelScope, IEvent, ILock, Task, TaskGroup + from ...lowlevel.api_async.backend.abc import CancelScope, IEvent, ILock, Task, TaskGroup class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _ResponseT]): @@ -47,7 +47,6 @@ class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp """ __slots__ = ( - "__backend", "__servers", "__listeners_factory", "__listeners_factory_scope", @@ -71,8 +70,6 @@ def __init__( *, reuse_port: bool = False, logger: logging.Logger | None = None, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, ) -> None: """ Parameters: @@ -88,11 +85,6 @@ def __init__( are bound to, so long as they all set this flag when being created. This option is not supported on Windows. logger: If given, the logger instance to use. - - Backend Parameters: - backend: the backend to use. Automatically determined otherwise. - backend_kwargs: Keyword arguments for backend instanciation. - Ignored if `backend` is already an :class:`.AsyncBackend` instance. """ super().__init__() @@ -101,7 +93,7 @@ def __init__( if not isinstance(request_handler, AsyncDatagramRequestHandler): raise TypeError(f"Expected an AsyncDatagramRequestHandler object, got {request_handler!r}") - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) + backend = current_async_backend() self.__listeners_factory: Callable[[], Coroutine[Any, Any, Sequence[AsyncDatagramListener[tuple[Any, ...]]]]] | None self.__listeners_factory = _utils.make_callback( @@ -112,23 +104,22 @@ def __init__( ) self.__listeners_factory_scope: CancelScope | None = None - self.__backend: AsyncBackend = backend - self.__servers: tuple[lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], ...] | None + self.__servers: tuple[_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], ...] | None self.__servers = None self.__protocol: DatagramProtocol[_ResponseT, _RequestT] = protocol self.__request_handler: AsyncDatagramRequestHandler[_RequestT, _ResponseT] = request_handler - self.__is_shutdown: IEvent = self.__backend.create_event() + self.__is_shutdown: IEvent = backend.create_event() self.__is_shutdown.set() self.__shutdown_asked: bool = False self.__servers_tasks: deque[Task[NoReturn]] = deque() self.__mainloop_task: Task[NoReturn] | None = None self.__logger: logging.Logger = logger or logging.getLogger(__name__) self.__clients_cache: defaultdict[ - lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], weakref.WeakValueDictionary[SocketAddress, _ClientAPI[_ResponseT]], ] self.__send_locks_cache: defaultdict[ - lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], weakref.WeakValueDictionary[SocketAddress, ILock], ] self.__clients_cache = defaultdict(weakref.WeakValueDictionary) @@ -162,7 +153,7 @@ async def __close_servers(self) -> None: self.__mainloop_task.cancel() exit_stack.push_async_callback(self.__mainloop_task.wait) - await self.__backend.cancel_shielded_coro_yield() + await current_async_backend().cancel_shielded_coro_yield() @_utils.inherit_doc(AbstractAsyncNetworkServer) async def shutdown(self) -> None: @@ -183,7 +174,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Wake up server if not self.__is_shutdown.is_set(): raise ServerAlreadyRunning("Server is already running") - self.__is_shutdown = is_shutdown = self.__backend.create_event() + self.__is_shutdown = is_shutdown = current_async_backend().create_event() server_exit_stack.callback(is_shutdown.set) ################ @@ -194,8 +185,8 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> raise ServerClosedError("Closed server") listeners: list[AsyncDatagramListener[tuple[Any, ...]]] = [] try: - with self.__backend.open_cancel_scope() as self.__listeners_factory_scope: - await self.__backend.coro_yield() + with current_async_backend().open_cancel_scope() as self.__listeners_factory_scope: + await current_async_backend().coro_yield() listeners.extend(await self.__listeners_factory()) if self.__listeners_factory_scope.cancelled_caught(): raise ServerClosedError("Closed server") @@ -203,14 +194,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> self.__listeners_factory_scope = None if not listeners: raise OSError("empty listeners list") - self.__servers = tuple( - lowlevel_datagram_server.AsyncDatagramServer( - listener, - self.__protocol, - backend=self.__backend, - ) - for listener in listeners - ) + self.__servers = tuple(_datagram_server.AsyncDatagramServer(listener, self.__protocol) for listener in listeners) del listeners ################### @@ -230,7 +214,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> # Setup task group server_exit_stack.callback(self.__servers_tasks.clear) - task_group: TaskGroup = await server_exit_stack.enter_async_context(self.__backend.create_task_group()) + task_group: TaskGroup = await server_exit_stack.enter_async_context(current_async_backend().create_task_group()) server_exit_stack.callback(self.__logger.info, "Server loop break, waiting for remaining tasks...") ################## @@ -245,7 +229,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> ############## # Main loop - self.__mainloop_task = task_group.start_soon(self.__backend.sleep_forever) + self.__mainloop_task = task_group.start_soon(current_async_backend().sleep_forever) if self.__shutdown_asked: self.__mainloop_task.cancel() try: @@ -257,7 +241,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> async def __serve( self, - server: lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + server: _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], task_group: TaskGroup, ) -> NoReturn: async with contextlib.aclosing(server): @@ -266,24 +250,13 @@ async def __serve( async def __datagram_received_coroutine( self, address: tuple[Any, ...], - server: lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + server: _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], ) -> AsyncGenerator[None, _RequestT]: address = new_socket_address(address, server.extra(INETSocketAttribute.family)) with self.__suppress_and_log_remaining_exception(client_address=address): - send_locks_cache = self.__send_locks_cache[server] - try: - send_lock = send_locks_cache[address] - except KeyError: - send_locks_cache[address] = send_lock = self.__backend.create_lock() - - clients_cache = self.__clients_cache[server] - try: - client = clients_cache[address] - except KeyError: - clients_cache[address] = client = _ClientAPI(address, server, send_lock, self.__logger) - - async with contextlib.aclosing(self.__request_handler.handle(client)) as request_handler_generator: - del client, send_lock + async with contextlib.aclosing( + self.__request_handler.handle(self.__get_client(server, address)) + ) as request_handler_generator: try: await anext(request_handler_generator) except StopAsyncIteration: @@ -320,6 +293,30 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress) self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc) self.__logger.error("-" * 40) + def __get_client_lock( + self, + server: _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + address: SocketAddress, + ) -> ILock: + send_locks_cache = self.__send_locks_cache[server] + try: + send_lock = send_locks_cache[address] + except KeyError: + send_locks_cache[address] = send_lock = current_async_backend().create_lock() + return send_lock + + def __get_client( + self, + server: _datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + address: SocketAddress, + ) -> _ClientAPI[_ResponseT]: + clients_cache = self.__clients_cache[server] + try: + client = clients_cache[address] + except KeyError: + clients_cache[address] = client = _ClientAPI(address, server, self.__get_client_lock(server, address), self.__logger) + return client + @_utils.inherit_doc(AbstractAsyncNetworkServer) def get_addresses(self) -> Sequence[SocketAddress]: if (servers := self.__servers) is None: @@ -330,10 +327,6 @@ def get_addresses(self) -> Sequence[SocketAddress]: if not server.is_closing() ) - @_utils.inherit_doc(AbstractAsyncNetworkServer) - def get_backend(self) -> AsyncBackend: - return self.__backend - @property def sockets(self) -> Sequence[SocketProxy]: """The listeners sockets. Read-only attribute.""" @@ -362,12 +355,12 @@ class _ClientAPI(AsyncDatagramClient[_ResponseT]): def __init__( self, address: SocketAddress, - server: lowlevel_datagram_server.AsyncDatagramServer[Any, _ResponseT, Any], + server: _datagram_server.AsyncDatagramServer[Any, _ResponseT, Any], send_lock: ILock, logger: logging.Logger, ) -> None: super().__init__() - self.__server_ref: weakref.ref[lowlevel_datagram_server.AsyncDatagramServer[Any, _ResponseT, Any]] = weakref.ref(server) + self.__server_ref: weakref.ref[_datagram_server.AsyncDatagramServer[Any, _ResponseT, Any]] = weakref.ref(server) self.__socket_proxy: SocketProxy = SocketProxy(server.extra(INETSocketAttribute.socket)) self.__h: int | None = None self.__send_lock: ILock = send_lock @@ -398,7 +391,7 @@ async def send_packet(self, packet: _ResponseT, /) -> None: _utils.check_real_socket_state(self.__socket_proxy) self.__logger.debug("Datagram successfully sent to %s.", self.__address) - def __check_closed(self) -> lowlevel_datagram_server.AsyncDatagramServer[Any, _ResponseT, Any]: + def __check_closed(self) -> _datagram_server.AsyncDatagramServer[Any, _ResponseT, Any]: server = self.__server_ref() if server is None or server.is_closing(): raise ClientClosedError("Closed client") diff --git a/src/easynetwork/api_sync/server/_base.py b/src/easynetwork/api_sync/server/_base.py index 18770d60..602f2c4e 100644 --- a/src/easynetwork/api_sync/server/_base.py +++ b/src/easynetwork/api_sync/server/_base.py @@ -21,14 +21,15 @@ import concurrent.futures import contextlib import threading as _threading -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any, NoReturn from ...api_async.server.abc import SupportsEventSet from ...exceptions import ServerAlreadyRunning, ServerClosedError from ...lowlevel import _utils from ...lowlevel._lock import ForkSafeLock -from ...lowlevel.api_async.backend.abc import ThreadsPortal +from ...lowlevel.api_async.backend.abc import AsyncBackend, ThreadsPortal +from ...lowlevel.api_async.backend.factory import AsyncBackendFactory as _Factory, current_async_backend as _current_backend from ...lowlevel.socket import SocketAddress from .abc import AbstractNetworkServer @@ -38,18 +39,22 @@ class BaseStandaloneNetworkServerImpl(AbstractNetworkServer): __slots__ = ( - "__server", + "__server_factory", + "__private_server", + "__backend", "__close_lock", "__bootstrap_lock", - "__threads_portal", + "__private_threads_portal", "__is_shutdown", "__is_closed", ) - def __init__(self, server: AbstractAsyncNetworkServer) -> None: + def __init__(self, backend: str, server_factory: Callable[[], AbstractAsyncNetworkServer]) -> None: super().__init__() - self.__server: AbstractAsyncNetworkServer = server - self.__threads_portal: ThreadsPortal | None = None + self.__backend: AsyncBackend = _Factory.get_backend(backend) + self.__server_factory: Callable[[], AbstractAsyncNetworkServer] = server_factory + self.__private_server: AbstractAsyncNetworkServer | None = None + self.__private_threads_portal: ThreadsPortal | None = None self.__is_shutdown = _threading.Event() self.__is_shutdown.set() self.__is_closed = _threading.Event() @@ -58,43 +63,43 @@ def __init__(self, server: AbstractAsyncNetworkServer) -> None: @_utils.inherit_doc(AbstractNetworkServer) def is_serving(self) -> bool: - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError): - return portal.run_sync(self.__server.is_serving) + return portal.run_sync(server.is_serving) return False @_utils.inherit_doc(AbstractNetworkServer) def server_close(self) -> None: with self.__close_lock.get(), contextlib.ExitStack() as stack, contextlib.suppress(RuntimeError): - if (portal := self._portal) is not None: + stack.callback(self.__is_closed.set) + + # Ensure we are not in the interval between the server shutdown and the scheduler shutdown + stack.callback(self.__is_shutdown.wait) + + if (server := self._server) is not None and (portal := self._portal) is not None: with contextlib.suppress(concurrent.futures.CancelledError): - portal.run_coroutine(self.__server.server_close) - else: - stack.callback(self.__is_closed.set) - self.__is_shutdown.wait() # Ensure we are not in the interval between the server shutdown and the scheduler shutdown - backend = self.__server.get_backend() - backend.bootstrap(self.__server.server_close) + portal.run_coroutine(server.server_close) @_utils.inherit_doc(AbstractNetworkServer) def shutdown(self, timeout: float | None = None) -> None: - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError): # If shutdown() have been cancelled, that means the scheduler itself is shutting down, and this is what we want if timeout is None: - portal.run_coroutine(self.__server.shutdown) + portal.run_coroutine(server.shutdown) else: elapsed = _utils.ElapsedTime() try: with elapsed: - portal.run_coroutine(self.__do_shutdown_with_timeout, timeout) + portal.run_coroutine(self.__do_shutdown_with_timeout, server, timeout) finally: timeout = elapsed.recompute_timeout(timeout) self.__is_shutdown.wait(timeout) - async def __do_shutdown_with_timeout(self, timeout_delay: float) -> None: - backend = self.__server.get_backend() - with backend.move_on_after(timeout_delay): - await self.__server.shutdown() + @staticmethod + async def __do_shutdown_with_timeout(server: AbstractAsyncNetworkServer, timeout_delay: float) -> None: + with _current_backend().move_on_after(timeout_delay): + await server.shutdown() def serve_forever( self, @@ -114,7 +119,7 @@ def serve_forever( ServerAlreadyRunning: Another task already called :meth:`serve_forever`. """ - backend = self.__server.get_backend() + backend = self.__backend with contextlib.ExitStack() as server_exit_stack, contextlib.suppress(backend.get_cancelled_exc_class()): # locks_stack is used to acquire locks until # serve_forever() coroutine creates the thread portal @@ -131,36 +136,42 @@ def serve_forever( self.__is_shutdown.clear() server_exit_stack.callback(self.__is_shutdown.set) - def reset_threads_portal() -> None: - self.__threads_portal = None + def reset_values() -> None: + self.__private_threads_portal = None + self.__private_server = None def acquire_bootstrap_lock() -> None: locks_stack.enter_context(self.__bootstrap_lock.get()) - server_exit_stack.callback(reset_threads_portal) + server_exit_stack.callback(reset_values) server_exit_stack.callback(acquire_bootstrap_lock) async def serve_forever() -> NoReturn: - async with backend.create_threads_portal() as self.__threads_portal: + async with ( + self.__server_factory() as self.__private_server, + backend.create_threads_portal() as self.__private_threads_portal, + ): + server = self.__private_server # Initialization finished; release the locks locks_stack.close() - await self.__server.serve_forever(is_up_event=is_up_event) + await server.serve_forever(is_up_event=is_up_event) backend.bootstrap(serve_forever, runner_options=runner_options) @_utils.inherit_doc(AbstractNetworkServer) def get_addresses(self) -> Sequence[SocketAddress]: - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError): - return portal.run_sync(self.__server.get_addresses) + return portal.run_sync(server.get_addresses) return () @property - def _server(self) -> AbstractAsyncNetworkServer: - return self.__server + def _server(self) -> AbstractAsyncNetworkServer | None: + with self.__bootstrap_lock.get(): + return self.__private_server @property def _portal(self) -> ThreadsPortal | None: with self.__bootstrap_lock.get(): - return self.__threads_portal + return self.__private_threads_portal diff --git a/src/easynetwork/api_sync/server/tcp.py b/src/easynetwork/api_sync/server/tcp.py index ed823b7c..e6821a94 100644 --- a/src/easynetwork/api_sync/server/tcp.py +++ b/src/easynetwork/api_sync/server/tcp.py @@ -21,7 +21,7 @@ ] import contextlib -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Generic from ..._typevars import _RequestT, _ResponseT @@ -35,7 +35,6 @@ from ssl import SSLContext as _SSLContext from ...api_async.server.handler import AsyncStreamRequestHandler - from ...lowlevel.api_async.backend.abc import AsyncBackend from ...protocol import StreamProtocol @@ -54,7 +53,7 @@ def __init__( port: int, protocol: StreamProtocol[_ResponseT, _RequestT], request_handler: AsyncStreamRequestHandler[_RequestT, _ResponseT], - backend: str | AsyncBackend = "asyncio", + backend: str = "asyncio", *, ssl: _SSLContext | None = None, ssl_handshake_timeout: float | None = None, @@ -64,7 +63,6 @@ def __init__( max_recv_size: int | None = None, log_client_connection: bool | None = None, logger: logging.Logger | None = None, - backend_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> None: """ @@ -72,13 +70,11 @@ def __init__( Note: The backend interface must be explicitly given. It defaults to ``asyncio``. - - :exc:`ValueError` is raised if :data:`None` is given. """ - if backend is None: - raise ValueError("You must explicitly give a backend name or instance") super().__init__( - AsyncTCPNetworkServer( + backend, + _utils.make_callback( + AsyncTCPNetworkServer, # type: ignore[arg-type] host=host, port=port, protocol=protocol, @@ -91,10 +87,8 @@ def __init__( max_recv_size=max_recv_size, log_client_connection=log_client_connection, logger=logger, - backend=backend, - backend_kwargs=backend_kwargs, **kwargs, - ) + ), ) def stop_listening(self) -> None: @@ -106,26 +100,21 @@ def stop_listening(self) -> None: Further calls to :meth:`is_serving` will return :data:`False`. """ - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError): - portal.run_sync(self._server.stop_listening) + portal.run_sync(server.stop_listening) @property @_utils.inherit_doc(AsyncTCPNetworkServer) def sockets(self) -> Sequence[SocketProxy]: - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError): - sockets = portal.run_sync(lambda: self._server.sockets) + sockets = portal.run_sync(lambda: server.sockets) return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets) return () - @property - @_utils.inherit_doc(AsyncTCPNetworkServer) - def logger(self) -> logging.Logger: - return self._server.logger - if TYPE_CHECKING: @property - def _server(self) -> AsyncTCPNetworkServer[_RequestT, _ResponseT]: + def _server(self) -> AsyncTCPNetworkServer[_RequestT, _ResponseT] | None: ... diff --git a/src/easynetwork/api_sync/server/udp.py b/src/easynetwork/api_sync/server/udp.py index 74707dba..19c14583 100644 --- a/src/easynetwork/api_sync/server/udp.py +++ b/src/easynetwork/api_sync/server/udp.py @@ -21,7 +21,7 @@ ] import contextlib -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Generic from ..._typevars import _RequestT, _ResponseT @@ -34,7 +34,6 @@ import logging from ...api_async.server.handler import AsyncDatagramRequestHandler - from ...lowlevel.api_async.backend.abc import AsyncBackend from ...protocol import DatagramProtocol @@ -53,11 +52,10 @@ def __init__( port: int, protocol: DatagramProtocol[_ResponseT, _RequestT], request_handler: AsyncDatagramRequestHandler[_RequestT, _ResponseT], - backend: str | AsyncBackend = "asyncio", + backend: str = "asyncio", *, reuse_port: bool = False, logger: logging.Logger | None = None, - backend_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> None: """ @@ -65,41 +63,32 @@ def __init__( Note: The backend interface must be explicitly given. It defaults to ``asyncio``. - - :exc:`ValueError` is raised if :data:`None` is given. """ - if backend is None: - raise ValueError("You must explicitly give a backend name or instance") super().__init__( - AsyncUDPNetworkServer( + backend, + _utils.make_callback( + AsyncUDPNetworkServer, # type: ignore[arg-type] host=host, port=port, protocol=protocol, request_handler=request_handler, reuse_port=reuse_port, logger=logger, - backend=backend, - backend_kwargs=backend_kwargs, **kwargs, - ) + ), ) @property @_utils.inherit_doc(AsyncUDPNetworkServer) def sockets(self) -> Sequence[SocketProxy]: - if (portal := self._portal) is not None: + if (portal := self._portal) is not None and (server := self._server) is not None: with contextlib.suppress(RuntimeError): - sockets = portal.run_sync(lambda: self._server.sockets) + sockets = portal.run_sync(lambda: server.sockets) return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets) return () - @property - @_utils.inherit_doc(AsyncUDPNetworkServer) - def logger(self) -> logging.Logger: - return self._server.logger - if TYPE_CHECKING: @property - def _server(self) -> AsyncUDPNetworkServer[_RequestT, _ResponseT]: + def _server(self) -> AsyncUDPNetworkServer[_RequestT, _ResponseT] | None: ... diff --git a/src/easynetwork/lowlevel/api_async/backend/factory.py b/src/easynetwork/lowlevel/api_async/backend/factory.py index af02b10b..4a6efa71 100644 --- a/src/easynetwork/lowlevel/api_async/backend/factory.py +++ b/src/easynetwork/lowlevel/api_async/backend/factory.py @@ -16,160 +16,99 @@ from __future__ import annotations -__all__ = ["AsyncBackendFactory"] +__all__ = ["AsyncBackendFactory", "current_async_backend"] import functools -import inspect -from collections import Counter -from collections.abc import Mapping -from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Final, final +import threading +from collections import deque +from collections.abc import Callable +from typing import Final, final +from ....exceptions import UnsupportedOperation +from ... import _lock from ..._final import runtime_final_class -from ._sniffio_helpers import current_async_library as _sniffio_current_async_library +from . import _sniffio_helpers from .abc import AsyncBackend -if TYPE_CHECKING: - from importlib.metadata import EntryPoint - @final @runtime_final_class class AsyncBackendFactory: - GROUP_NAME: Final[str] = "easynetwork.async.backends" - __BACKEND: str | type[AsyncBackend] | None = None - __BACKEND_EXTENSIONS: Final[dict[str, type[AsyncBackend]]] = {} - - @staticmethod - def get_default_backend(guess_current_async_library: bool = True) -> type[AsyncBackend]: - backend: str | type[AsyncBackend] | None = AsyncBackendFactory.__BACKEND - if isinstance(backend, type): - return backend - if backend is None: - if guess_current_async_library: - backend = _sniffio_current_async_library() # must raise if not recognized - else: - backend = "asyncio" - return AsyncBackendFactory.__get_backend_cls( - backend, - "Running library {name!r} misses the backend implementation", - extended=True, - ) - - @staticmethod - def set_default_backend(backend: str | type[AsyncBackend] | None) -> None: - match backend: - case type() if not issubclass(backend, AsyncBackend) or inspect.isabstract(backend): - raise TypeError(f"Invalid backend class: {backend!r}") - case type() | None: - pass - case str(): - AsyncBackendFactory.__get_backend_cls(backend, extended=False) - case _: # pragma: no cover - raise TypeError(f"Invalid argument: {backend!r}") - - AsyncBackendFactory.__BACKEND = backend - - @staticmethod - def extend(backend_name: str, backend_cls: type[AsyncBackend] | None) -> None: - default_backend_cls = AsyncBackendFactory.__get_backend_cls(backend_name, extended=False) - if backend_cls is None or backend_cls is default_backend_cls: - AsyncBackendFactory.__BACKEND_EXTENSIONS.pop(backend_name, None) - return - if not issubclass(backend_cls, default_backend_cls): - raise TypeError(f"Invalid backend class (not a subclass of {default_backend_cls!r}): {backend_cls!r}") - AsyncBackendFactory.__BACKEND_EXTENSIONS[backend_name] = backend_cls - - @staticmethod - def new(backend: str | None = None, /, **kwargs: Any) -> AsyncBackend: - backend_cls: type[AsyncBackend] - if backend is None: - backend_cls = AsyncBackendFactory.get_default_backend(guess_current_async_library=True) - else: - backend_cls = AsyncBackendFactory.__get_backend_cls(backend, extended=True) - return backend_cls(**kwargs) - - @staticmethod - def ensure(backend: str | AsyncBackend | None, kwargs: Mapping[str, Any] | None = None) -> AsyncBackend: - if not isinstance(backend, AsyncBackend): - if kwargs is None: - kwargs = {} - backend = AsyncBackendFactory.new(backend, **kwargs) - return backend - - @staticmethod - def get_all_backends(*, extended: bool = True) -> MappingProxyType[str, type[AsyncBackend]]: - backends = { - name: AsyncBackendFactory.__get_backend_cls(name, extended=extended) - for name in AsyncBackendFactory.__get_available_backends() - } - return MappingProxyType(backends) - - @staticmethod - def get_available_backends() -> frozenset[str]: - return frozenset(AsyncBackendFactory.__get_available_backends()) - - @staticmethod - def remove_all_extensions() -> None: - AsyncBackendFactory.__BACKEND_EXTENSIONS.clear() - - @staticmethod - def invalidate_backends_cache() -> None: - AsyncBackendFactory.remove_all_extensions() - AsyncBackendFactory.__load_backend_cls_from_entry_point.cache_clear() - AsyncBackendFactory.__get_available_backends.cache_clear() - - @staticmethod - def __get_backend_cls( - name: str, - error_msg_format: str = "Unknown backend {name!r}", - *, - extended: bool, - ) -> type[AsyncBackend]: - if extended: + __lock: Final[_lock.ForkSafeLock[threading.RLock]] = _lock.ForkSafeLock(threading.RLock) + __hooks: Final[deque[Callable[[str], AsyncBackend]]] = deque() + __instances: Final[dict[str, AsyncBackend]] = {} + + @classmethod + def current(cls) -> AsyncBackend: + name: str = _sniffio_helpers.current_async_library() + return cls.__get_backend(name, "Running library {name!r} misses the backend implementation") + + @classmethod + def get_backend(cls, name: str, /) -> AsyncBackend: + return cls.__get_backend(name, "Unknown backend {name!r}") + + @classmethod + def push_factory_hook(cls, factory: Callable[[str], AsyncBackend], /) -> None: + if not callable(factory): + raise TypeError(f"{factory!r} is not callable") + with cls.__lock.get(): + cls.__hooks.appendleft(factory) + + @classmethod + def push_backend_factory(cls, backend_name: str, factory: Callable[[], AsyncBackend]) -> None: + if not isinstance(backend_name, str): + raise TypeError("backend_name: Expected a string") + if backend_name.strip() != backend_name or not backend_name: + raise ValueError("backend_name: Invalid value") + if not callable(factory): + raise TypeError(f"{factory!r} is not callable") + return cls.push_factory_hook(functools.partial(cls.__backend_factory_hook, backend_name, factory)) + + @classmethod + def invalidate_backends_cache(cls) -> None: + with cls.__lock.get(): + cls.__instances.clear() + + @classmethod + def remove_installed_hooks(cls) -> None: + with cls.__lock.get(): + cls.__hooks.clear() + + @classmethod + def __get_backend(cls, name: str, error_msg_format: str) -> AsyncBackend: + with cls.__lock.get(): try: - return AsyncBackendFactory.__BACKEND_EXTENSIONS[name] + return cls.__instances[name] except KeyError: pass - try: - return AsyncBackendFactory.__load_backend_cls_from_entry_point(name) - except KeyError: - raise KeyError(error_msg_format.format(name=name)) from None - @staticmethod - @functools.cache - def __load_backend_cls_from_entry_point(name: str) -> type[AsyncBackend]: - entry_point: EntryPoint = AsyncBackendFactory.__get_available_backends()[name] - - entry_point_cls: Any = entry_point.load() - if ( - not isinstance(entry_point_cls, type) - or not issubclass(entry_point_cls, AsyncBackend) - or inspect.isabstract(entry_point_cls) - ): - raise TypeError(f"Invalid backend entry point (name={name!r}): {entry_point_cls!r}") - return entry_point_cls + backend_instance: AsyncBackend | None = None + for factory_hook in cls.__hooks: + try: + backend_instance = factory_hook(name) + except UnsupportedOperation: + continue - @staticmethod - @functools.cache - def __get_available_backends() -> MappingProxyType[str, EntryPoint]: - from importlib.metadata import EntryPoint, entry_points as get_all_entry_points + if not isinstance(backend_instance, AsyncBackend): + raise TypeError(f"{factory_hook!r} did not return an AsyncBackend instance") + break + + if backend_instance is None and name == "asyncio": + from ...std_asyncio import AsyncIOBackend - entry_points = get_all_entry_points(group=AsyncBackendFactory.GROUP_NAME) - duplicate_counter: Counter[str] = Counter([ep.name for ep in entry_points]) + backend_instance = AsyncIOBackend() - if duplicates := {name for name in duplicate_counter if duplicate_counter[name] > 1}: - raise TypeError(f"Conflicting backend name caught: {', '.join(map(repr, sorted(duplicates)))}") + if backend_instance is None: + raise NotImplementedError(error_msg_format.format(name=name)) - backends: dict[str, EntryPoint] = {ep.name: ep for ep in entry_points} + cls.__instances[name] = backend_instance + return backend_instance - if "asyncio" not in backends: - from importlib.util import resolve_name + @staticmethod + def __backend_factory_hook(backend_name: str, factory: Callable[[], AsyncBackend], name: str, /) -> AsyncBackend: + if name != backend_name: + raise UnsupportedOperation(f"{name!r} backend is not implemented") + return factory() - backends["asyncio"] = EntryPoint( - name="asyncio", - value=f"{resolve_name('...std_asyncio', __package__)}:AsyncIOBackend", - group=AsyncBackendFactory.GROUP_NAME, - ) - return MappingProxyType(backends) +current_async_backend = AsyncBackendFactory.current diff --git a/src/easynetwork/lowlevel/api_async/backend/futures.py b/src/easynetwork/lowlevel/api_async/backend/futures.py index 8b76831e..7067d854 100644 --- a/src/easynetwork/lowlevel/api_async/backend/futures.py +++ b/src/easynetwork/lowlevel/api_async/backend/futures.py @@ -22,17 +22,15 @@ import contextvars import functools from collections import deque -from collections.abc import AsyncGenerator, Callable, Iterable, Mapping +from collections.abc import AsyncGenerator, Callable, Iterable from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar from . import _sniffio_helpers -from .factory import AsyncBackendFactory +from .factory import current_async_backend if TYPE_CHECKING: from types import TracebackType - from .abc import AsyncBackend - _P = ParamSpec("_P") _T = TypeVar("_T") @@ -65,22 +63,17 @@ async def main() -> None: results = [await t.join() for t in tasks] """ - __slots__ = ("__backend", "__executor", "__handle_contexts", "__weakref__") + __slots__ = ("__executor", "__handle_contexts", "__weakref__") def __init__( self, executor: concurrent.futures.Executor, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, *, handle_contexts: bool = False, ) -> None: """ Parameters: executor: The executor instance to wrap. - backend: the backend to use. Automatically determined otherwise. - backend_kwargs: Keyword arguments for backend instanciation. - Ignored if `backend` is already an :class:`.AsyncBackend` instance. handle_contexts: If :data:`True`, contexts (:class:`contextvars.Context`) are properly propagated to workers. Defaults to :data:`False` because not all executors support the use of contexts (e.g. :class:`concurrent.futures.ProcessPoolExecutor`). @@ -88,7 +81,6 @@ def __init__( if not isinstance(executor, concurrent.futures.Executor): raise TypeError("Invalid executor type") - self.__backend: AsyncBackend = AsyncBackendFactory.ensure(backend, backend_kwargs) self.__executor: concurrent.futures.Executor = executor self.__handle_contexts: bool = bool(handle_contexts) @@ -132,8 +124,7 @@ async def run(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwar """ func = self._setup_func(func) executor = self.__executor - backend = self.__backend - return await _result_or_cancel(backend, executor.submit(func, *args, **kwargs)) + return await _result_or_cancel(executor.submit(func, *args, **kwargs)) def map(self, func: Callable[..., _T], *iterables: Iterable[Any]) -> AsyncGenerator[_T, None]: """ @@ -158,14 +149,13 @@ def pow_50(x): An asynchronous iterator equivalent to ``map(func, *iterables)`` but the calls may be evaluated out-of-order. """ - backend = self.__backend executor = self.__executor fs = deque(executor.submit(self._setup_func(func), *args) for args in zip(*iterables)) async def result_iterator() -> AsyncGenerator[_T, None]: try: while fs: - yield await _result_or_cancel(backend, fs.popleft()) + yield await _result_or_cancel(fs.popleft()) finally: for future in fs: future.cancel() @@ -201,7 +191,7 @@ async def shutdown(self, *, cancel_futures: bool = False) -> None: has not started running. Any futures that are completed or running won't be cancelled, regardless of the value of `cancel_futures`. """ - await self.__backend.run_in_thread(self.__executor.shutdown, wait=True, cancel_futures=cancel_futures) + await current_async_backend().run_in_thread(self.__executor.shutdown, wait=True, cancel_futures=cancel_futures) def _setup_func(self, func: Callable[_P, _T]) -> Callable[_P, _T]: if self.__handle_contexts: @@ -211,10 +201,10 @@ def _setup_func(self, func: Callable[_P, _T]) -> Callable[_P, _T]: return func -async def _result_or_cancel(backend: AsyncBackend, future: concurrent.futures.Future[_T]) -> _T: +async def _result_or_cancel(future: concurrent.futures.Future[_T]) -> _T: try: try: - return await backend.wait_future(future) + return await current_async_backend().wait_future(future) finally: future.cancel() finally: diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index 4c6f028f..62ceb467 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -32,6 +32,7 @@ from ....exceptions import DatagramProtocolParseError from ... import _asyncgen, _utils, typed_attr from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup +from ..backend.factory import current_async_backend from ..transports import abc as transports _T_Address = TypeVar("_T_Address", bound=Hashable) @@ -43,7 +44,6 @@ class AsyncDatagramServer(typed_attr.TypedAttributeProvider, Generic[_RequestT, __slots__ = ( "__listener", "__protocol", - "__backend", "__client_manager", "__sendto_lock", "__serve_guard", @@ -54,24 +54,16 @@ def __init__( self, listener: transports.AsyncDatagramListener[_T_Address], protocol: protocol_module.DatagramProtocol[_ResponseT, _RequestT], - *, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, ) -> None: if not isinstance(listener, transports.AsyncDatagramListener): raise TypeError(f"Expected an AsyncDatagramListener object, got {listener!r}") if not isinstance(protocol, protocol_module.DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") - from ..backend.factory import AsyncBackendFactory - - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) - self.__listener: transports.AsyncDatagramListener[_T_Address] = listener self.__protocol: protocol_module.DatagramProtocol[_ResponseT, _RequestT] = protocol - self.__backend: AsyncBackend = backend - self.__client_manager: _ClientManager[_T_Address] = _ClientManager(self.__backend) - self.__sendto_lock: ILock = self.__backend.create_lock() + self.__client_manager: _ClientManager[_T_Address] = _ClientManager(current_async_backend()) + self.__sendto_lock: ILock = current_async_backend().create_lock() self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receiving datagrams") def is_closing(self) -> bool: @@ -124,7 +116,7 @@ async def serve( with self.__serve_guard: client_coroutine = self.__client_coroutine client_manager = self.__client_manager - backend = self.__backend + backend = current_async_backend() listener = self.__listener async def handler(datagram: bytes, address: _T_Address, /) -> None: @@ -159,12 +151,6 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None: del datagram, address await backend.cancel_shielded_coro_yield() - def get_backend(self) -> AsyncBackend: - """ - Return the underlying backend interface. - """ - return self.__backend - async def __client_coroutine( self, datagram_received_cb: Callable[[_T_Address, Self], AsyncGenerator[None, _RequestT]], diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index a1a07dff..44212684 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -26,7 +26,7 @@ from ...._typevars import _RequestT, _ResponseT from ....exceptions import UnsupportedOperation from ... import _asyncgen, _stream, _utils, typed_attr -from ..backend.abc import AsyncBackend, TaskGroup +from ..backend.abc import TaskGroup from ..transports import abc as transports, utils as transports_utils @@ -90,7 +90,6 @@ class AsyncStreamServer(typed_attr.TypedAttributeProvider, Generic[_RequestT, _R "__listener", "__protocol", "__max_recv_size", - "__backend", "__serve_guard", "__weakref__", ) @@ -100,9 +99,6 @@ def __init__( listener: transports.AsyncListener[transports.AsyncStreamTransport], protocol: protocol_module.StreamProtocol[_ResponseT, _RequestT], max_recv_size: int, - *, - backend: str | AsyncBackend | None = None, - backend_kwargs: Mapping[str, Any] | None = None, ) -> None: if not isinstance(listener, transports.AsyncListener): raise TypeError(f"Expected an AsyncListener object, got {listener!r}") @@ -111,14 +107,9 @@ def __init__( if not isinstance(max_recv_size, int) or max_recv_size <= 0: raise ValueError("'max_recv_size' must be a strictly positive integer") - from ..backend.factory import AsyncBackendFactory - - backend = AsyncBackendFactory.ensure(backend, backend_kwargs) - self.__listener: transports.AsyncListener[transports.AsyncStreamTransport] = listener self.__protocol: protocol_module.StreamProtocol[_ResponseT, _RequestT] = protocol self.__max_recv_size: int = max_recv_size - self.__backend: AsyncBackend = backend self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently accepting new connections") def is_closing(self) -> bool: @@ -145,12 +136,6 @@ async def serve( handler = _utils.prepend_argument(client_connected_cb)(self.__client_coroutine) await self.__listener.serve(handler, task_group) - def get_backend(self) -> AsyncBackend: - """ - Return the underlying backend interface. - """ - return self.__backend - async def __client_coroutine( self, client_connected_cb: Callable[[AsyncStreamClient[_ResponseT]], AsyncGenerator[None, _RequestT]], @@ -160,7 +145,7 @@ async def __client_coroutine( raise TypeError(f"Expected an AsyncStreamTransport object, got {transport!r}") async with contextlib.AsyncExitStack() as client_exit_stack: - client_exit_stack.push_async_callback(transports_utils.aclose_forcefully, self.__backend, transport) + client_exit_stack.push_async_callback(transports_utils.aclose_forcefully, transport) producer = _stream.StreamDataProducer(self.__protocol) consumer: _stream.StreamDataConsumer[_RequestT] | _stream.BufferedStreamDataConsumer[_RequestT] diff --git a/src/easynetwork/lowlevel/api_async/transports/utils.py b/src/easynetwork/lowlevel/api_async/transports/utils.py index 5d7de76c..260a5e7f 100644 --- a/src/easynetwork/lowlevel/api_async/transports/utils.py +++ b/src/easynetwork/lowlevel/api_async/transports/utils.py @@ -18,21 +18,16 @@ __all__ = ["aclose_forcefully"] -from typing import TYPE_CHECKING - +from ..backend.factory import current_async_backend from .abc import AsyncBaseTransport -if TYPE_CHECKING: - from ..backend.abc import AsyncBackend - -async def aclose_forcefully(backend: AsyncBackend, transport: AsyncBaseTransport) -> None: +async def aclose_forcefully(transport: AsyncBaseTransport) -> None: """ Close an async resource or async generator immediately, without blocking to do any graceful cleanup. Parameters: - backend: the backend to use. transport: the transport to close. """ - with backend.move_on_after(0): + with current_async_backend().move_on_after(0): await transport.aclose() diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index ff7b79f4..04719fa0 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -26,7 +26,7 @@ class TestAsyncioBackend: @pytest.fixture @staticmethod def backend() -> AsyncIOBackend: - backend = AsyncBackendFactory.new("asyncio") + backend = AsyncBackendFactory.get_backend("asyncio") assert isinstance(backend, AsyncIOBackend) return backend @@ -1051,7 +1051,7 @@ class TestAsyncioBackendShieldedCancellation: @pytest.fixture @staticmethod def backend() -> AsyncIOBackend: - backend = AsyncBackendFactory.new("asyncio") + backend = AsyncBackendFactory.get_backend("asyncio") assert isinstance(backend, AsyncIOBackend) return backend diff --git a/tests/functional_test/test_async/test_backend/test_backend_factory.py b/tests/functional_test/test_async/test_backend/test_backend_factory.py deleted file mode 100644 index dbb179f3..00000000 --- a/tests/functional_test/test_async/test_backend/test_backend_factory.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator - -from easynetwork.lowlevel.api_async.backend.factory import AsyncBackendFactory - -import pytest - - -class TestAsyncBackendFactory: - @pytest.fixture(autouse=True) - @staticmethod - def setup_factory() -> Iterator[None]: - AsyncBackendFactory.invalidate_backends_cache() - yield - AsyncBackendFactory.set_default_backend(None) - - def test____get_available_backends____imports_asyncio_backend(self) -> None: - assert AsyncBackendFactory.get_available_backends() == frozenset({"asyncio"}) - - def test____get_all_backends____imports_asyncio_backend(self) -> None: - from easynetwork.lowlevel.std_asyncio import AsyncIOBackend - - assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncIOBackend} - - def test____get_default_backend____returns_asyncio_backend(self) -> None: - from easynetwork.lowlevel.std_asyncio import AsyncIOBackend - - assert AsyncBackendFactory.get_default_backend(guess_current_async_library=False) is AsyncIOBackend - - def test____new____returns_asyncio_backend_instance(self) -> None: - from easynetwork.lowlevel.std_asyncio import AsyncIOBackend - - backend = AsyncBackendFactory.new("asyncio") - - assert isinstance(backend, AsyncIOBackend) diff --git a/tests/functional_test/test_communication/test_async/conftest.py b/tests/functional_test/test_communication/test_async/conftest.py index 5e10ea54..da74cd96 100644 --- a/tests/functional_test/test_communication/test_async/conftest.py +++ b/tests/functional_test/test_communication/test_async/conftest.py @@ -1,7 +1,13 @@ from __future__ import annotations +from collections.abc import Iterator + +from easynetwork.lowlevel.std_asyncio import AsyncIOBackend + import pytest +from ....tools import temporary_backend + use_asyncio_transport_xfail_uvloop = pytest.mark.parametrize( "use_asyncio_transport", [ @@ -14,5 +20,9 @@ @pytest.fixture(params=[False, True], ids=lambda boolean: f"use_asyncio_transport=={boolean}") -def use_asyncio_transport(request: pytest.FixtureRequest) -> bool: - return getattr(request, "param") +def use_asyncio_transport(request: pytest.FixtureRequest) -> Iterator[bool]: + use_asyncio_transport: bool = getattr(request, "param") + + # TODO: Do not use temporary_backend() when env variable will be implemented + with temporary_backend(AsyncIOBackend(transport=use_asyncio_transport)): + yield use_asyncio_transport diff --git a/tests/functional_test/test_communication/test_async/test_client/test_tcp.py b/tests/functional_test/test_communication/test_async/test_client/test_tcp.py index 5f09c86f..172a9eec 100644 --- a/tests/functional_test/test_communication/test_async/test_client/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_client/test_tcp.py @@ -5,7 +5,6 @@ import ssl from collections.abc import AsyncIterator from socket import AF_INET, IPPROTO_TCP, SHUT_WR, TCP_NODELAY, socket as Socket -from typing import Any from easynetwork.api_async.client.tcp import AsyncTCPNetworkClient from easynetwork.exceptions import ClientClosedError, StreamProtocolParseError @@ -31,7 +30,7 @@ async def readline(loop: asyncio.AbstractEventLoop, sock: Socket) -> bytes: @pytest.mark.asyncio -@pytest.mark.usefixtures("simulate_no_ssl_module") +@pytest.mark.usefixtures("simulate_no_ssl_module", "use_asyncio_transport") class TestAsyncTCPNetworkClient: @pytest.fixture @staticmethod @@ -46,13 +45,8 @@ def server(socket_pair: tuple[Socket, Socket]) -> Socket: async def client( socket_pair: tuple[Socket, Socket], stream_protocol: StreamProtocol[str, str], - use_asyncio_transport: bool, ) -> AsyncIterator[AsyncTCPNetworkClient[str, str]]: - async with AsyncTCPNetworkClient( - socket_pair[1], - stream_protocol, - backend_kwargs={"transport": use_asyncio_transport}, - ) as client: + async with AsyncTCPNetworkClient(socket_pair[1], stream_protocol) as client: assert client.is_connected() yield client @@ -339,7 +333,7 @@ async def test____get_remote_address____consistency( @pytest.mark.asyncio -@pytest.mark.usefixtures("simulate_no_ssl_module") +@pytest.mark.usefixtures("simulate_no_ssl_module", "use_asyncio_transport") class TestAsyncTCPNetworkClientConnection: @pytest_asyncio.fixture(autouse=True) @staticmethod @@ -367,18 +361,12 @@ async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.Stre def remote_address(server: asyncio.Server) -> tuple[str, int]: return server.sockets[0].getsockname()[:2] - @pytest.fixture - @staticmethod - def backend_kwargs(use_asyncio_transport: bool) -> dict[str, Any]: - return {"transport": use_asyncio_transport} - async def test____dunder_init____connect_to_server( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with AsyncTCPNetworkClient(remote_address, stream_protocol, backend_kwargs=backend_kwargs) as client: + async with AsyncTCPNetworkClient(remote_address, stream_protocol) as client: assert client.is_connected() await client.send_packet("Test") assert await client.recv_packet() == "Test" @@ -388,14 +376,8 @@ async def test____dunder_init____with_local_address( localhost_ip: str, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with AsyncTCPNetworkClient( - remote_address, - stream_protocol, - local_address=(localhost_ip, 0), - backend_kwargs=backend_kwargs, - ) as client: + async with AsyncTCPNetworkClient(remote_address, stream_protocol, local_address=(localhost_ip, 0)) as client: await client.send_packet("Test") assert await client.recv_packet() == "Test" @@ -403,14 +385,9 @@ async def test____wait_connected____idempotent( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ), + AsyncTCPNetworkClient(remote_address, stream_protocol), ) as client: assert not client.is_connected() await client.wait_connected() @@ -422,15 +399,8 @@ async def test____wait_connected____simultaneous( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ), - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: await asyncio.gather(*[client.wait_connected() for _ in range(5)]) assert client.is_connected() @@ -438,15 +408,8 @@ async def test____wait_connected____is_closing____connection_not_performed_yet( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: assert not client.is_connected() assert not client.is_closing() await client.wait_connected() @@ -457,15 +420,8 @@ async def test____wait_connected____close_before_trying_to_connect( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: await client.aclose() with pytest.raises(ClientClosedError): await client.wait_connected() @@ -474,15 +430,8 @@ async def test____socket_property____connection_not_performed_yet( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: with pytest.raises(AttributeError): _ = client.socket @@ -494,15 +443,8 @@ async def test____get_local_address____connection_not_performed_yet( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: with pytest.raises(OSError): _ = client.get_local_address() @@ -514,15 +456,8 @@ async def test____get_remote_address____connection_not_performed_yet( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: with pytest.raises(OSError): _ = client.get_remote_address() @@ -534,15 +469,8 @@ async def test____send_packet____recv_packet____implicit_connection( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncTCPNetworkClient( - remote_address, - stream_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncTCPNetworkClient(remote_address, stream_protocol)) as client: assert not client.is_connected() await client.send_packet("Connected") @@ -580,16 +508,10 @@ async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.Stre def remote_address(server: asyncio.Server) -> tuple[str, int]: return server.sockets[0].getsockname()[:2] - @pytest.fixture - @staticmethod - def backend_kwargs() -> dict[str, Any]: - return {} - async def test____dunder_init____handshake_and_shutdown( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], client_ssl_context: ssl.SSLContext, ) -> None: # Arrange @@ -600,7 +522,6 @@ async def test____dunder_init____handshake_and_shutdown( stream_protocol, ssl=client_ssl_context, server_hostname="test.example.com", - backend_kwargs=backend_kwargs, ) as client: await client.send_packet("Test") assert await client.recv_packet() == "Test" @@ -610,7 +531,6 @@ async def test____dunder_init____use_default_context( remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], client_ssl_context: ssl.SSLContext, - backend_kwargs: dict[str, Any], monkeypatch: pytest.MonkeyPatch, ) -> None: # Arrange @@ -622,7 +542,6 @@ async def test____dunder_init____use_default_context( stream_protocol, ssl=True, server_hostname="test.example.com", - backend_kwargs=backend_kwargs, ) as client: await client.send_packet("Test") assert await client.recv_packet() == "Test" @@ -632,7 +551,6 @@ async def test____dunder_init____use_default_context____disable_hostname_check( remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], client_ssl_context: ssl.SSLContext, - backend_kwargs: dict[str, Any], monkeypatch: pytest.MonkeyPatch, ) -> None: # Arrange @@ -646,7 +564,6 @@ async def test____dunder_init____use_default_context____disable_hostname_check( stream_protocol, ssl=True, server_hostname="", - backend_kwargs=backend_kwargs, ) as client: assert not client_ssl_context.check_hostname # It must be set to False if server_hostname is an empty string await client.send_packet("Test") @@ -657,7 +574,6 @@ async def test____dunder_init____no_ssl_module_available( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], client_ssl_context: ssl.SSLContext, ) -> None: # Arrange @@ -669,14 +585,12 @@ async def test____dunder_init____no_ssl_module_available( stream_protocol, ssl=client_ssl_context, server_hostname="test.example.com", - backend_kwargs=backend_kwargs, ) async def test____send_eof____not_supported( self, remote_address: tuple[str, int], stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], client_ssl_context: ssl.SSLContext, ) -> None: # Arrange @@ -687,7 +601,6 @@ async def test____send_eof____not_supported( stream_protocol, ssl=client_ssl_context, server_hostname="test.example.com", - backend_kwargs=backend_kwargs, ) as client: with pytest.raises(NotImplementedError): await client.send_eof() diff --git a/tests/functional_test/test_communication/test_async/test_client/test_udp.py b/tests/functional_test/test_communication/test_async/test_client/test_udp.py index ef4b2551..bf52f69d 100644 --- a/tests/functional_test/test_communication/test_async/test_client/test_udp.py +++ b/tests/functional_test/test_communication/test_async/test_client/test_udp.py @@ -4,7 +4,6 @@ import contextlib from collections.abc import AsyncIterator, Awaitable, Callable from socket import AF_INET, socket as Socket -from typing import Any from easynetwork.api_async.client.udp import AsyncUDPNetworkClient from easynetwork.exceptions import ClientClosedError, DatagramProtocolParseError @@ -55,6 +54,7 @@ async def factory() -> DatagramEndpoint: @pytest.mark.asyncio +@pytest.mark.usefixtures("use_asyncio_transport") class TestAsyncUDPNetworkClient: @pytest_asyncio.fixture @staticmethod @@ -80,15 +80,12 @@ async def client( remote_address: tuple[str, int], use_external_socket: Socket | None, datagram_protocol: DatagramProtocol[str, str], - use_asyncio_transport: bool, ) -> AsyncIterator[AsyncUDPNetworkClient[str, str]]: if use_external_socket is not None: use_external_socket.connect(remote_address) - client = AsyncUDPNetworkClient( - use_external_socket, datagram_protocol, backend_kwargs={"transport": use_asyncio_transport} - ) + client = AsyncUDPNetworkClient(use_external_socket, datagram_protocol) else: - client = AsyncUDPNetworkClient(remote_address, datagram_protocol, backend_kwargs={"transport": use_asyncio_transport}) + client = AsyncUDPNetworkClient(remote_address, datagram_protocol) async with client: assert client.is_connected() yield client @@ -262,6 +259,7 @@ async def test____get_remote_address____consistency( @pytest.mark.asyncio +@pytest.mark.usefixtures("use_asyncio_transport") class TestAsyncUDPNetworkClientConnection: class EchoProtocol(asyncio.DatagramProtocol): transport: asyncio.DatagramTransport | None = None @@ -306,24 +304,12 @@ async def server( def remote_address(server: asyncio.DatagramTransport) -> tuple[str, int]: return server.get_extra_info("sockname")[:2] - @pytest.fixture - @staticmethod - def backend_kwargs(use_asyncio_transport: bool) -> dict[str, Any]: - return {"transport": use_asyncio_transport} - async def test____wait_connected____idempotent( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: await client.wait_connected() assert client.is_connected() await client.wait_connected() @@ -333,15 +319,8 @@ async def test____wait_connected____simultaneous( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: await asyncio.gather(*[client.wait_connected() for _ in range(5)]) assert client.is_connected() @@ -349,15 +328,8 @@ async def test____wait_connected____is_closing____connection_not_performed_yet( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: assert not client.is_connected() assert not client.is_closing() await client.wait_connected() @@ -368,15 +340,8 @@ async def test____wait_connected____close_before_trying_to_connect( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: await client.aclose() with pytest.raises(ClientClosedError): await client.wait_connected() @@ -385,15 +350,8 @@ async def test____socket_property____connection_not_performed_yet( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: with pytest.raises(AttributeError): _ = client.socket @@ -405,15 +363,8 @@ async def test____get_local_address____connection_not_performed_yet( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: with pytest.raises(OSError): _ = client.get_local_address() @@ -425,15 +376,8 @@ async def test____get_remote_address____connection_not_performed_yet( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: with pytest.raises(OSError): _ = client.get_remote_address() @@ -446,15 +390,8 @@ async def test____send_packet____recv_packet____implicit_connection( self, remote_address: tuple[str, int], datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], ) -> None: - async with contextlib.aclosing( - AsyncUDPNetworkClient( - remote_address, - datagram_protocol, - backend_kwargs=backend_kwargs, - ) - ) as client: + async with contextlib.aclosing(AsyncUDPNetworkClient(remote_address, datagram_protocol)) as client: assert not client.is_connected() async with asyncio.timeout(3): diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 0429ce35..b0f222f4 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -20,6 +20,7 @@ StreamProtocolParseError, ) from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend +from easynetwork.lowlevel.api_async.backend.factory import current_async_backend from easynetwork.lowlevel.socket import SocketAddress, enable_socket_linger from easynetwork.lowlevel.std_asyncio._asyncio_utils import create_connection from easynetwork.lowlevel.std_asyncio.backend import AsyncIOBackend @@ -30,6 +31,7 @@ import pytest import pytest_asyncio +from .....tools import temporary_backend from .base import BaseTestAsyncServer @@ -100,7 +102,7 @@ async def on_connection(self, client: AsyncStreamClient[str]) -> None: if self.milk_handshake: await client.send_packet("milk") if self.close_all_clients_on_connection: - await self.backend.sleep(0.1) + await current_async_backend().sleep(0.1) await client.aclose() async def on_disconnection(self, client: AsyncStreamClient[str]) -> None: @@ -164,17 +166,13 @@ async def handle_bad_requests(self, client: AsyncStreamClient[str]) -> AsyncIter self.bad_request_received[client_address(client)].append(exc) await client.send_packet("wrong encoding man.") - @property - def backend(self) -> AsyncBackend: - return self.server.get_backend() - class TimeoutRequestHandler(AsyncStreamRequestHandler[str, str]): request_timeout: float = 1.0 timeout_on_second_yield: bool = False async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: AsyncTCPNetworkServer[str, str]) -> None: - self.backend = server.get_backend() + self.backend = current_async_backend() async def on_connection(self, client: AsyncStreamClient[str]) -> None: await client.send_packet("milk") @@ -206,7 +204,7 @@ class InitialHandshakeRequestHandler(AsyncStreamRequestHandler[str, str]): bypass_handshake: bool = False async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: AsyncTCPNetworkServer[str, str]) -> None: - self.backend = server.get_backend() + self.backend = current_async_backend() async def on_connection(self, client: AsyncStreamClient[str]) -> AsyncGenerator[None, str]: await client.send_packet("milk") @@ -299,10 +297,10 @@ def use_ssl(request: Any) -> bool: @pytest.fixture @staticmethod - def backend_kwargs(use_asyncio_transport: bool, use_ssl: bool) -> dict[str, Any]: + def use_asyncio_transport(use_asyncio_transport: bool, use_ssl: bool) -> bool: if use_ssl and not use_asyncio_transport: pytest.skip("SSL/TLS not supported with transport=False") - return {"transport": use_asyncio_transport} + return use_asyncio_transport @pytest.fixture @staticmethod @@ -324,7 +322,7 @@ async def server( use_ssl: bool, server_ssl_context: ssl.SSLContext, ssl_handshake_timeout: float | None, - backend_kwargs: dict[str, Any], + use_asyncio_transport: bool, # Only here for dependency ) -> AsyncIterator[MyAsyncTCPServer]: async with MyAsyncTCPServer( localhost_ip, @@ -334,7 +332,6 @@ async def server( backlog=1, ssl=server_ssl_context if use_ssl else None, ssl_handshake_timeout=ssl_handshake_timeout, - backend_kwargs=backend_kwargs, ) as server: assert not server.sockets assert not server.get_addresses() @@ -404,13 +401,13 @@ async def _wait_client_disconnected(writer: asyncio.StreamWriter, request_handle @pytest.mark.parametrize("host", [None, ""], ids=repr) @pytest.mark.parametrize("log_client_connection", [True, False], ids=lambda p: f"log_client_connection=={p}") @pytest.mark.parametrize("use_ssl", ["NO_SSL"], indirect=True) + @pytest.mark.usefixtures("use_asyncio_transport") async def test____dunder_init____bind_to_all_available_interfaces( self, host: str | None, log_client_connection: bool, request_handler: MyAsyncTCPRequestHandler, stream_protocol: StreamProtocol[str, str], - backend_kwargs: dict[str, Any], caplog: pytest.LogCaptureFixture, ) -> None: async with MyAsyncTCPServer( @@ -418,7 +415,6 @@ async def test____dunder_init____bind_to_all_available_interfaces( 0, stream_protocol, request_handler, - backend_kwargs=backend_kwargs, log_client_connection=log_client_connection, ) as s: caplog.set_level(logging.DEBUG, s.logger.name) @@ -488,11 +484,12 @@ async def test____serve_forever____empty_listener_list( request_handler: MyAsyncTCPRequestHandler, stream_protocol: StreamProtocol[str, str], ) -> None: - async with MyAsyncTCPServer(None, 0, stream_protocol, request_handler, backend=NoListenerErrorBackend()) as s: - with pytest.raises(OSError, match=r"^empty listeners list$"): - await s.serve_forever() + with temporary_backend(NoListenerErrorBackend()): + async with MyAsyncTCPServer(None, 0, stream_protocol, request_handler) as s: + with pytest.raises(OSError, match=r"^empty listeners list$"): + await s.serve_forever() - assert not s.sockets + assert not s.sockets @pytest.mark.usefixtures("run_server_and_wait") async def test____serve_forever____server_assignment( diff --git a/tests/functional_test/test_communication/test_async/test_server/test_udp.py b/tests/functional_test/test_communication/test_async/test_server/test_udp.py index f245838c..285c13c0 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_udp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_udp.py @@ -24,6 +24,7 @@ import pytest_asyncio from .....pytest_plugins.asyncio_event_loop import EventLoop +from .....tools import temporary_backend from .base import BaseTestAsyncServer @@ -206,11 +207,11 @@ class MyAsyncUDPServer(AsyncUDPNetworkServer[str, str]): class TestAsyncUDPNetworkServer(BaseTestAsyncServer): @pytest.fixture @staticmethod - def backend_kwargs(event_loop_name: EventLoop, use_asyncio_transport: bool) -> dict[str, Any]: + def use_asyncio_transport(event_loop_name: EventLoop, use_asyncio_transport: bool) -> bool: if not use_asyncio_transport: if event_loop_name == EventLoop.UVLOOP: pytest.xfail("uvloop runner does not implement the needed functions") - return {"transport": use_asyncio_transport} + return use_asyncio_transport @pytest.fixture @staticmethod @@ -224,15 +225,9 @@ async def server( request_handler: AsyncDatagramRequestHandler[str, str], localhost_ip: str, datagram_protocol: DatagramProtocol[str, str], - backend_kwargs: dict[str, Any], + use_asyncio_transport: bool, # Only here for dependency ) -> AsyncIterator[MyAsyncUDPServer]: - async with MyAsyncUDPServer( - localhost_ip, - 0, - datagram_protocol, - request_handler, - backend_kwargs=backend_kwargs, - ) as server: + async with MyAsyncUDPServer(localhost_ip, 0, datagram_protocol, request_handler) as server: assert not server.sockets assert not server.get_addresses() yield server @@ -277,11 +272,12 @@ async def test____serve_forever____empty_listener_list( request_handler: MyAsyncUDPRequestHandler, datagram_protocol: DatagramProtocol[str, str], ) -> None: - async with MyAsyncUDPServer(None, 0, datagram_protocol, request_handler, backend=NoListenerErrorBackend()) as s: - with pytest.raises(OSError, match=r"^empty listeners list$"): - await s.serve_forever() + with temporary_backend(NoListenerErrorBackend()): + async with MyAsyncUDPServer(None, 0, datagram_protocol, request_handler) as s: + with pytest.raises(OSError, match=r"^empty listeners list$"): + await s.serve_forever() - assert not s.sockets + assert not s.sockets @pytest.mark.usefixtures("run_server_and_wait") async def test____serve_forever____server_assignment( diff --git a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py index d8cc9522..9fce7147 100644 --- a/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py +++ b/tests/functional_test/test_communication/test_sync/test_server/test_standalone.py @@ -65,6 +65,9 @@ def test____server_close____idempotent(self, server: AbstractNetworkServer) -> N def test____server_close____while_server_is_running(self, server: AbstractNetworkServer) -> None: server.server_close() + with pytest.raises(ServerClosedError): + server.serve_forever() + @pytest.mark.usefixtures("start_server") def test____serve_forever____error_server_already_running(self, server: AbstractNetworkServer) -> None: with pytest.raises(ServerAlreadyRunning): @@ -110,16 +113,6 @@ def client(server: StandaloneTCPNetworkServer[str, str], start_server: None) -> with socket.create_connection(("localhost", port)) as client: yield client - def test____dunder_init____invalid_backend(self, stream_protocol: StreamProtocol[str, str]) -> None: - with pytest.raises(ValueError, match=r"^You must explicitly give a backend name or instance$"): - _ = StandaloneTCPNetworkServer( - None, - 0, - stream_protocol, - EchoRequestHandler(), - backend=None, # type: ignore[arg-type] - ) - def test____serve_forever____serve_several_times(self, server: StandaloneTCPNetworkServer[str, str]) -> None: with server: for _ in range(3): @@ -152,9 +145,6 @@ def test____stop_listening____stop_accepting_new_connection(self, server: Standa assert len(server.sockets) > 0 # Sockets are closed, but always available until server_close() call assert len(server.get_addresses()) == 0 - def test____logger_property____exposed(self, server: StandaloneTCPNetworkServer[str, str]) -> None: - assert server.logger is server._server.logger - class TestStandaloneUDPNetworkServer(BaseTestStandaloneNetworkServer): @pytest.fixture @@ -162,16 +152,6 @@ class TestStandaloneUDPNetworkServer(BaseTestStandaloneNetworkServer): def server(datagram_protocol: DatagramProtocol[str, str]) -> StandaloneUDPNetworkServer[str, str]: return StandaloneUDPNetworkServer("localhost", 0, datagram_protocol, EchoRequestHandler()) - def test____dunder_init____invalid_backend(self, datagram_protocol: DatagramProtocol[str, str]) -> None: - with pytest.raises(ValueError, match=r"^You must explicitly give a backend name or instance$"): - _ = StandaloneUDPNetworkServer( - "localhost", - 0, - datagram_protocol, - EchoRequestHandler(), - backend=None, # type: ignore[arg-type] - ) - def test____serve_forever____serve_several_times(self, server: StandaloneUDPNetworkServer[str, str]) -> None: with server: for _ in range(3): @@ -196,6 +176,3 @@ def test____socket_property____server_is_not_running(self, server: StandaloneUDP def test____socket_property____server_is_running(self, server: StandaloneUDPNetworkServer[str, str]) -> None: assert len(server.sockets) > 0 assert len(server.get_addresses()) > 0 - - def test____logger_property____exposed(self, server: StandaloneUDPNetworkServer[str, str]) -> None: - assert server.logger is server._server.logger diff --git a/tests/scripts/async_client_test.py b/tests/scripts/async_client_test.py index 070863df..67b499c6 100644 --- a/tests/scripts/async_client_test.py +++ b/tests/scripts/async_client_test.py @@ -13,7 +13,7 @@ def create_tcp_client(port: int) -> AsyncTCPNetworkClient[str, str]: - return AsyncTCPNetworkClient(("localhost", port), StreamProtocol(StringLineSerializer()), backend_kwargs={"transport": False}) + return AsyncTCPNetworkClient(("localhost", port), StreamProtocol(StringLineSerializer())) def create_udp_client(port: int) -> AsyncUDPNetworkClient[str, str]: diff --git a/tests/tools.py b/tests/tools.py index 5b4c1a74..d27e0d96 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -1,15 +1,20 @@ from __future__ import annotations import asyncio +import contextlib import importlib import sys import time -from collections.abc import Generator +from collections.abc import Generator, Iterator from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, assert_never, final +from easynetwork.lowlevel.api_async.backend.factory import AsyncBackendFactory + import pytest if TYPE_CHECKING: + from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend + from _typeshed import WriteableBuffer _T_contra = TypeVar("_T_contra", contravariant=True) @@ -135,3 +140,14 @@ def write_data_and_extra_in_buffer( too_short_buffer=too_short_buffer_for_extra_data, ) return complete_data_nbytes + extra_data_nbytes, extra_data[:extra_data_nbytes] + + +@contextlib.contextmanager +def temporary_backend(backend: AsyncBackend) -> Iterator[None]: + with contextlib.ExitStack() as stack: + stack.callback(AsyncBackendFactory.invalidate_backends_cache) + stack.callback(AsyncBackendFactory.remove_installed_hooks) + AsyncBackendFactory.invalidate_backends_cache() + AsyncBackendFactory.push_backend_factory("asyncio", lambda: backend) + assert AsyncBackendFactory.get_backend("asyncio") is backend + yield diff --git a/tests/unit_test/test_async/test_api/test_client/base.py b/tests/unit_test/test_async/test_api/test_client/base.py index f7ea6fe9..8955196c 100644 --- a/tests/unit_test/test_async/test_api/test_client/base.py +++ b/tests/unit_test/test_async/test_api/test_client/base.py @@ -1,7 +1,20 @@ from __future__ import annotations +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import pytest + +from .....tools import temporary_backend from ....base import BaseTestSocket +if TYPE_CHECKING: + from unittest.mock import MagicMock + class BaseTestClient(BaseTestSocket): - pass + @pytest.fixture(autouse=True) + @staticmethod + def mock_backend(mock_backend: MagicMock) -> Iterator[MagicMock]: + with temporary_backend(mock_backend): + yield mock_backend diff --git a/tests/unit_test/test_async/test_api/test_client/test_abc.py b/tests/unit_test/test_async/test_api/test_client/test_abc.py index 7d26e1fb..908c9902 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_abc.py +++ b/tests/unit_test/test_async/test_api/test_client/test_abc.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from collections.abc import Iterator from typing import TYPE_CHECKING, Any, final from easynetwork.api_async.client.abc import AbstractAsyncNetworkClient @@ -8,19 +9,18 @@ import pytest +from .....tools import temporary_backend + if TYPE_CHECKING: from unittest.mock import MagicMock - from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend - from pytest_mock import MockerFixture @final class MockAsyncClient(AbstractAsyncNetworkClient[Any, Any]): - def __init__(self, mock_backend: MagicMock, mocker: MockerFixture) -> None: + def __init__(self, mocker: MockerFixture) -> None: super().__init__() - self.mock_backend = mock_backend self.mock_wait_connected = mocker.AsyncMock(return_value=None) self.mock_close = mocker.AsyncMock(return_value=None) self.mock_recv_packet = mocker.AsyncMock() @@ -49,16 +49,19 @@ async def send_packet(self, packet: Any) -> None: async def recv_packet(self) -> Any: return await self.mock_recv_packet() - def get_backend(self) -> AsyncBackend: - return self.mock_backend - @pytest.mark.asyncio class TestAbstractAsyncNetworkClient: + @pytest.fixture(autouse=True) + @staticmethod + def mock_backend(mock_backend: MagicMock) -> Iterator[MagicMock]: + with temporary_backend(mock_backend): + yield mock_backend + @pytest.fixture @staticmethod - def client(mock_backend: MagicMock, mocker: MockerFixture) -> MockAsyncClient: - return MockAsyncClient(mock_backend, mocker) + def client(mocker: MockerFixture) -> MockAsyncClient: + return MockAsyncClient(mocker) async def test____context____close_client_at_end(self, client: MockAsyncClient) -> None: # Arrange diff --git a/tests/unit_test/test_async/test_api/test_client/test_tcp.py b/tests/unit_test/test_async/test_api/test_client/test_tcp.py index fe7aee14..e79ed544 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_tcp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_tcp.py @@ -66,13 +66,6 @@ def mock_stream_data_consumer(mocker: MockerFixture) -> MagicMock: def mock_stream_data_consumer_cls(mocker: MockerFixture, mock_stream_data_consumer: MagicMock) -> MagicMock: return mocker.patch("easynetwork.lowlevel._stream.StreamDataConsumer", return_value=mock_stream_data_consumer) - @pytest.fixture(autouse=True) - @staticmethod - def mock_new_backend(mocker: MockerFixture, mock_backend: MagicMock) -> MagicMock: - from easynetwork.lowlevel.api_async.backend.factory import AsyncBackendFactory - - return mocker.patch.object(AsyncBackendFactory, "new", return_value=mock_backend) - @pytest.fixture(autouse=True) @classmethod def local_address( @@ -156,18 +149,13 @@ def next_side_effect() -> Any: mock_stream_data_consumer.__iter__.side_effect = lambda: mock_stream_data_consumer mock_stream_data_consumer.__next__.side_effect = next_side_effect - @pytest.fixture + @pytest_asyncio.fixture @staticmethod - def client_not_connected( + async def client_not_connected( remote_address: tuple[str, int], - mock_backend: MagicMock, mock_stream_protocol: MagicMock, ) -> AsyncTCPNetworkClient[Any, Any]: - client: AsyncTCPNetworkClient[Any, Any] = AsyncTCPNetworkClient( - remote_address, - mock_stream_protocol, - backend=mock_backend, - ) + client: AsyncTCPNetworkClient[Any, Any] = AsyncTCPNetworkClient(remote_address, mock_stream_protocol) assert not client.is_connected() return client @@ -191,7 +179,6 @@ async def test____dunder_init____connect_to_remote( mock_backend: MagicMock, mock_stream_data_consumer_cls: MagicMock, mock_stream_protocol: MagicMock, - mock_new_backend: MagicMock, mocker: MockerFixture, ) -> None: # Arrange @@ -207,7 +194,6 @@ async def test____dunder_init____connect_to_remote( await client.wait_connected() # Assert - mock_new_backend.assert_called_once_with(None) mock_stream_data_consumer_cls.assert_called_once_with(mock_stream_protocol) mock_backend.create_tcp_connection.assert_awaited_once_with( *remote_address, @@ -220,51 +206,12 @@ async def test____dunder_init____connect_to_remote( ] assert isinstance(client.socket, SocketProxy) - async def test____dunder_init____backend____from_string( - self, - remote_address: tuple[str, int], - mock_stream_protocol: MagicMock, - mock_new_backend: MagicMock, - ) -> None: - # Arrange - - # Act - _ = AsyncTCPNetworkClient( - remote_address, - protocol=mock_stream_protocol, - backend="custom_backend", - backend_kwargs={"arg1": 1, "arg2": "2"}, - ) - - # Assert - mock_new_backend.assert_called_once_with("custom_backend", arg1=1, arg2="2") - - async def test____dunder_init____backend____explicit_argument( - self, - remote_address: tuple[str, int], - mock_stream_protocol: MagicMock, - mock_backend: MagicMock, - mock_new_backend: MagicMock, - ) -> None: - # Arrange - - # Act - _ = AsyncTCPNetworkClient( - remote_address, - protocol=mock_stream_protocol, - backend=mock_backend, - ) - - # Assert - mock_new_backend.assert_not_called() - async def test____dunder_init____use_given_socket( self, mock_tcp_socket: MagicMock, mock_backend: MagicMock, mock_stream_data_consumer_cls: MagicMock, mock_stream_protocol: MagicMock, - mock_new_backend: MagicMock, mocker: MockerFixture, ) -> None: # Arrange @@ -274,7 +221,6 @@ async def test____dunder_init____use_given_socket( await client.wait_connected() # Assert - mock_new_backend.assert_called_once_with(None) mock_stream_data_consumer_cls.assert_called_once_with(mock_stream_protocol) mock_backend.wrap_stream_socket.assert_awaited_once_with(mock_tcp_socket) assert mock_tcp_socket.mock_calls == [ @@ -414,7 +360,6 @@ async def test____dunder_init____ssl( mock_backend: MagicMock, mock_stream_data_consumer_cls: MagicMock, mock_stream_protocol: MagicMock, - mock_new_backend: MagicMock, mock_ssl_context: MagicMock, mock_ssl_create_default_context: MagicMock, mocker: MockerFixture, @@ -448,7 +393,6 @@ async def test____dunder_init____ssl( await client.wait_connected() # Assert - mock_new_backend.assert_called_once_with(None) mock_stream_data_consumer_cls.assert_called_once_with(mock_stream_protocol) mock_ssl_create_default_context.assert_not_called() if use_socket: @@ -1437,7 +1381,6 @@ async def test____special_case____separate_send_and_receive_locks____ssl( ssl_shared_lock: bool | None, mock_stream_socket_adapter: MagicMock, mock_stream_protocol: MagicMock, - mock_backend: MagicMock, mock_ssl_context: MagicMock, mocker: MockerFixture, ) -> None: @@ -1448,7 +1391,6 @@ async def test____special_case____separate_send_and_receive_locks____ssl( ssl=mock_ssl_context, server_hostname="server_hostname", ssl_shared_lock=ssl_shared_lock, - backend=mock_backend, ) await client.wait_connected() @@ -1475,13 +1417,3 @@ async def recv_side_effect(bufsize: int) -> bytes: mock_stream_protocol.generate_chunks.assert_called_with(mocker.sentinel.packet) mock_stream_socket_adapter.send_all_from_iterable.assert_called() mock_stream_socket_adapter.send_all.assert_called_with(b"packet\n") - - async def test____get_backend____default( - self, - client_connected_or_not: AsyncTCPNetworkClient[Any, Any], - mock_backend: MagicMock, - ) -> None: - # Arrange - - # Act & Assert - assert client_connected_or_not.get_backend() is mock_backend diff --git a/tests/unit_test/test_async/test_api/test_client/test_udp.py b/tests/unit_test/test_async/test_api/test_client/test_udp.py index da170dcb..0809110f 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_udp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_udp.py @@ -23,13 +23,6 @@ from .base import BaseTestClient -@pytest.fixture(autouse=True) -def mock_new_backend(mocker: MockerFixture, mock_backend: MagicMock) -> MagicMock: - from easynetwork.lowlevel.api_async.backend.factory import AsyncBackendFactory - - return mocker.patch.object(AsyncBackendFactory, "new", return_value=mock_backend) - - @pytest.mark.asyncio class TestAsyncUDPNetworkClient(BaseTestClient): @pytest.fixture(scope="class", params=["AF_INET", "AF_INET6"]) @@ -105,18 +98,13 @@ def build_packet_from_datagram_side_effect(data: bytes) -> Any: mock_datagram_protocol.make_datagram.side_effect = make_datagram_side_effect mock_datagram_protocol.build_packet_from_datagram.side_effect = build_packet_from_datagram_side_effect - @pytest.fixture + @pytest_asyncio.fixture @staticmethod - def client_not_connected( - mock_backend: MagicMock, + async def client_not_connected( mock_udp_socket: MagicMock, mock_datagram_protocol: MagicMock, ) -> AsyncUDPNetworkClient[Any, Any]: - client: AsyncUDPNetworkClient[Any, Any] = AsyncUDPNetworkClient( - mock_udp_socket, - mock_datagram_protocol, - backend=mock_backend, - ) + client: AsyncUDPNetworkClient[Any, Any] = AsyncUDPNetworkClient(mock_udp_socket, mock_datagram_protocol) assert not client.is_connected() return client @@ -138,7 +126,6 @@ async def test____dunder_init____with_remote_address( remote_address: tuple[str, int], mock_udp_socket: MagicMock, mock_datagram_protocol: MagicMock, - mock_new_backend: MagicMock, mock_backend: MagicMock, mocker: MockerFixture, ) -> None: @@ -153,7 +140,6 @@ async def test____dunder_init____with_remote_address( await client.wait_connected() # Assert - mock_new_backend.assert_called_once_with(None) mock_backend.create_udp_endpoint.assert_awaited_once_with( *remote_address, local_address=mocker.sentinel.local_address, @@ -255,49 +241,10 @@ async def test____dunder_init____with_remote_address____force_local_address( local_address=("localhost", 0), ) - async def test____dunder_init____backend____from_string( - self, - remote_address: tuple[str, int], - mock_datagram_protocol: MagicMock, - mock_new_backend: MagicMock, - ) -> None: - # Arrange - - # Act - _ = AsyncUDPNetworkClient( - remote_address, - mock_datagram_protocol, - backend="custom_backend", - backend_kwargs={"arg1": 1, "arg2": "2"}, - ) - - # Assert - mock_new_backend.assert_called_once_with("custom_backend", arg1=1, arg2="2") - - async def test____dunder_init____backend____explicit_argument( - self, - remote_address: tuple[str, int], - mock_datagram_protocol: MagicMock, - mock_backend: MagicMock, - mock_new_backend: MagicMock, - ) -> None: - # Arrange - - # Act - _ = AsyncUDPNetworkClient( - remote_address, - mock_datagram_protocol, - backend=mock_backend, - ) - - # Assert - mock_new_backend.assert_not_called() - async def test____dunder_init____use_given_socket( self, mock_udp_socket: MagicMock, mock_datagram_protocol: MagicMock, - mock_new_backend: MagicMock, mock_backend: MagicMock, mocker: MockerFixture, ) -> None: @@ -312,7 +259,6 @@ async def test____dunder_init____use_given_socket( # Assert mock_udp_socket.bind.assert_not_called() - mock_new_backend.assert_called_once_with(None) mock_backend.wrap_connected_datagram_socket.assert_awaited_once_with(mock_udp_socket) assert mock_udp_socket.mock_calls == [ mocker.call.getpeername(), @@ -773,13 +719,3 @@ async def test____recv_packet____convert_closed_socket_error( # Assert mock_datagram_socket_adapter.recv.assert_awaited_once() mock_datagram_protocol.build_packet_from_datagram.assert_not_called() - - async def test____get_backend____default( - self, - client_connected_or_not: AsyncUDPNetworkClient[Any, Any], - mock_backend: MagicMock, - ) -> None: - # Arrange - - # Act & Assert - assert client_connected_or_not.get_backend() is mock_backend diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py index 09f5cb82..c29404c8 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py @@ -6,6 +6,7 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, assert_never, final +from easynetwork.exceptions import UnsupportedOperation from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend from easynetwork.lowlevel.api_async.backend.factory import AsyncBackendFactory @@ -14,7 +15,6 @@ from ._fake_backends import BaseFakeBackend, FakeAsyncIOBackend, FakeCurioBackend, FakeTrioBackend if TYPE_CHECKING: - from importlib.metadata import EntryPoint from unittest.mock import MagicMock from pytest_mock import MockerFixture @@ -95,389 +95,365 @@ class TestAsyncBackendFactory: BACKEND_CLS_TO_NAME: MappingProxyType[type[AsyncBackend], str] = MappingProxyType({v: k for k, v in BACKENDS.items()}) + CURRENT_ASYNC_LIBRARY_FUNC: str = "easynetwork.lowlevel.api_async.backend._sniffio_helpers.current_async_library" + @pytest.fixture(scope="class", autouse=True) @staticmethod def reset_factory_cache_at_end() -> Iterator[None]: # Drop after all tests are done in order not to impact next tests yield AsyncBackendFactory.invalidate_backends_cache() - - from easynetwork.lowlevel.std_asyncio import AsyncIOBackend - - assert AsyncBackendFactory.get_all_backends() == {"asyncio": AsyncIOBackend} - - @pytest.fixture(autouse=True) - @staticmethod - def setup_factory() -> Iterator[None]: - AsyncBackendFactory.invalidate_backends_cache() - yield - AsyncBackendFactory.set_default_backend(None) - AsyncBackendFactory.remove_all_extensions() + AsyncBackendFactory.remove_installed_hooks() @pytest.fixture(autouse=True) @classmethod - def mock_importlib_metadata_entry_points(cls, mocker: MockerFixture) -> MagicMock: - return mocker.patch( - "importlib.metadata.entry_points", - autospec=True, - return_value=list(map(cls.build_entry_point, cls.BACKENDS)), - ) - - @classmethod - def build_entry_point( - cls, - name: str, - value: str = "", - ) -> EntryPoint: - from importlib.metadata import EntryPoint - - if not value: - try: - _default_cls = cls.BACKENDS[name] - except KeyError: - _default_cls = MockBackend - value = f"{_default_cls.__module__}:{_default_cls.__name__}" - - return EntryPoint(name, value, AsyncBackendFactory.GROUP_NAME) - - def test____get_all_backends____default(self) -> None: - # Arrange - - # Act - backends = AsyncBackendFactory.get_all_backends() - - # Assert - assert backends == self.BACKENDS - - def test____get_all_backends____entry_point_is_module(self, mock_importlib_metadata_entry_points: MagicMock) -> None: - # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio", __name__), - ] - - # Act & Assert - with pytest.raises(TypeError, match=r"^Invalid backend entry point \(name='asyncio'\): .+$"): - AsyncBackendFactory.get_all_backends() + def setup_factory(cls) -> None: + AsyncBackendFactory.invalidate_backends_cache() + AsyncBackendFactory.remove_installed_hooks() + for backend_name, backend_cls in cls.BACKENDS.items(): + AsyncBackendFactory.push_backend_factory(backend_name, backend_cls) - def test____get_all_backends____entry_point_is_not_async_backend_class( - self, - mock_importlib_metadata_entry_points: MagicMock, - ) -> None: + def test____push_factory_hook____not_callable(self, mocker: MockerFixture) -> None: # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio", "builtins:int"), - ] + obj = mocker.NonCallableMock() # Act & Assert - with pytest.raises(TypeError, match=r"^Invalid backend entry point \(name='asyncio'\): .+$"): - AsyncBackendFactory.get_all_backends() + with pytest.raises(TypeError, match=r"^.+ is not callable$"): + AsyncBackendFactory.push_factory_hook(obj) - def test____get_all_backends____entry_point_is_abstract( - self, - mock_importlib_metadata_entry_points: MagicMock, - ) -> None: + def test____push_backend_factory____not_callable(self, mocker: MockerFixture) -> None: # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio", "easynetwork.lowlevel.api_async.backend.abc:AsyncBackend"), - ] + obj = mocker.NonCallableMock() # Act & Assert - with pytest.raises(TypeError, match=r"^Invalid backend entry point \(name='asyncio'\): .+$"): - AsyncBackendFactory.get_all_backends() + with pytest.raises(TypeError, match=r"^.+ is not callable$"): + AsyncBackendFactory.push_backend_factory("mock", obj) - def test____get_all_backends____entry_point_module_not_found( - self, - mock_importlib_metadata_entry_points: MagicMock, - ) -> None: + def test____push_backend_factory____backend_is_not_a_string(self, mocker: MockerFixture) -> None: # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio", "unknown_module:Backend"), - ] + factory = mocker.stub() # Act & Assert - with pytest.raises(ModuleNotFoundError, match=r"^No module named 'unknown_module'$"): - AsyncBackendFactory.get_all_backends() + with pytest.raises(TypeError, match=r"^backend_name: Expected a string$"): + AsyncBackendFactory.push_backend_factory(4, factory) # type: ignore[arg-type] - def test____get_all_backends____duplicate( - self, - mock_importlib_metadata_entry_points: MagicMock, - ) -> None: + @pytest.mark.parametrize("invalid_token", ["", " ", "\nasyncio", "asyncio\t"], ids=repr) + def test____push_backend_factory____backend_string_is_invalid(self, invalid_token: str, mocker: MockerFixture) -> None: # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio"), - self.build_entry_point("asyncio"), - self.build_entry_point("asyncio"), - ] + factory = mocker.stub() # Act & Assert - with pytest.raises(TypeError, match=r"^Conflicting backend name caught: 'asyncio'$"): - AsyncBackendFactory.get_all_backends() + with pytest.raises(ValueError, match=r"^backend_name: Invalid value$"): + AsyncBackendFactory.push_backend_factory(invalid_token, factory) - def test____get_available_backends____default( + @pytest.mark.parametrize("backend_name", list(BACKENDS)) + def test____get_backend____returns_backend_instance_from_hook( self, - mock_importlib_metadata_entry_points: MagicMock, + backend_name: str, ) -> None: - # Arrange - mock_importlib_metadata_entry_points.return_value = [ - self.build_entry_point("asyncio"), - self.build_entry_point("trio"), - self.build_entry_point("curio"), - self.build_entry_point("mock", f"{__name__}:MockBackend"), - ] - - # Act - available_backends = AsyncBackendFactory.get_available_backends() - - # Assert - assert isinstance(available_backends, frozenset) - assert available_backends == frozenset({"asyncio", "trio", "curio", "mock"}) - - @pytest.mark.parametrize("backend_name", list(BACKENDS)) - @pytest.mark.parametrize("extended", [False, True], ids=lambda extended: f"extended=={extended}") - def test____set_default_backend____from_string(self, backend_name: str, extended: bool) -> None: # Arrange expected_cls = self.BACKENDS[backend_name] - if extended: - - class ExtendedBackend(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass - - AsyncBackendFactory.extend(backend_name, ExtendedBackend) - expected_cls = ExtendedBackend # Act - AsyncBackendFactory.set_default_backend(backend_name) + backend = AsyncBackendFactory.get_backend(backend_name) # Assert - assert AsyncBackendFactory.get_default_backend(guess_current_async_library=False) is expected_cls + assert isinstance(backend, AsyncBackend) + assert isinstance(backend, expected_cls) - def test____set_default_backend____from_string____unknown_backend(self) -> None: + def test____get_backend____unknown_backend(self) -> None: # Arrange # Act & Assert - with pytest.raises(KeyError): - AsyncBackendFactory.set_default_backend("unknown") + with pytest.raises(NotImplementedError, match=r"^Unknown backend 'mock'$"): + AsyncBackendFactory.get_backend("mock") - @pytest.mark.parametrize("backend_cls", [*BACKENDS.values(), MockBackend]) - @pytest.mark.parametrize("extended", [False, True], ids=lambda extended: f"extended=={extended}") - def test____set_default_backend____from_class(self, backend_cls: type[AsyncBackend], extended: bool) -> None: + @pytest.mark.parametrize("backend_name", list(BACKENDS)) + def test____get_backend____by_default(self, backend_name: str) -> None: # Arrange - if extended: - try: - _backend_name = self.BACKEND_CLS_TO_NAME[backend_cls] - except KeyError: - pytest.skip("Not an entry-point") + AsyncBackendFactory.remove_installed_hooks() - class ExtendedBackend(backend_cls): # type: ignore[valid-type,misc] - pass + # Act & Assert + match backend_name: + case "asyncio": + from easynetwork.lowlevel.std_asyncio import AsyncIOBackend + + assert type(AsyncBackendFactory.get_backend("asyncio")) is AsyncIOBackend + case _: + with pytest.raises(NotImplementedError, match=r"^Unknown backend '.+'$"): + AsyncBackendFactory.get_backend(backend_name) - AsyncBackendFactory.extend(_backend_name, ExtendedBackend) + @pytest.mark.parametrize("backend_name", list(BACKENDS)) + def test____get_backend____singleton(self, backend_name: str) -> None: + # Arrange + first_backend_instance = AsyncBackendFactory.get_backend(backend_name) # Act - AsyncBackendFactory.set_default_backend(backend_cls) + this_backend = AsyncBackendFactory.get_backend(backend_name) # Assert - assert AsyncBackendFactory.get_default_backend(guess_current_async_library=False) is backend_cls + assert this_backend is first_backend_instance - @pytest.mark.parametrize("invalid_cls", [int, Socket, TestAsyncBackend]) - def test____set_default_backend____from_class____error_do_not_derive_from_AsyncBackend( + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) + def test____get_backend____add_hook( self, - invalid_cls: type[Any], + action: Literal["push_backend_factory", "push_factory_hook"], + mocker: MockerFixture, ) -> None: # Arrange + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory("mock", lambda: MockBackend(mocker)) + case "push_factory_hook": - # Act & Assert - with pytest.raises(TypeError, match=rf"^Invalid backend class: {invalid_cls!r}$"): - AsyncBackendFactory.set_default_backend(invalid_cls) - - def test____set_default_backend____from_class____error_abstract_class_given(self) -> None: - # Arrange - - # Act & Assert - with pytest.raises(TypeError, match=rf"^Invalid backend class: {AsyncBackend!r}$"): - AsyncBackendFactory.set_default_backend(AsyncBackend) + def hook(name: str) -> AsyncBackend: + if name == "mock": + return MockBackend(mocker) + raise UnsupportedOperation - @pytest.mark.parametrize("backend_name", list(BACKENDS)) - def test____extend____replace_by_a_subclass(self, backend_name: str) -> None: - # Arrange - class ExtendedBackend(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass + AsyncBackendFactory.push_factory_hook(hook) + case _: + assert_never(action) # Act - AsyncBackendFactory.extend(backend_name, ExtendedBackend) + backend = AsyncBackendFactory.get_backend("mock") # Assert - assert AsyncBackendFactory.get_all_backends(extended=True)[backend_name] is ExtendedBackend - assert AsyncBackendFactory.get_all_backends(extended=False)[backend_name] is self.BACKENDS[backend_name] + assert type(backend) is MockBackend + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) @pytest.mark.parametrize("backend_name", list(BACKENDS)) - @pytest.mark.parametrize("method", ["using_None", "using_base_cls"]) - def test____extend____remove_extension(self, backend_name: str, method: Literal["using_None", "using_base_cls"]) -> None: + def test____get_backend____override_backend( + self, + backend_name: str, + action: Literal["push_backend_factory", "push_factory_hook"], + mocker: MockerFixture, + ) -> None: # Arrange - class ExtendedBackend(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory(backend_name, lambda: MockBackend(mocker)) + case "push_factory_hook": - AsyncBackendFactory.extend(backend_name, ExtendedBackend) - assert AsyncBackendFactory.get_all_backends(extended=True)[backend_name] is ExtendedBackend + def hook(name: str) -> AsyncBackend: + if name == backend_name: + return MockBackend(mocker) + raise UnsupportedOperation - # Act - match method: - case "using_None": - AsyncBackendFactory.extend(backend_name, None) - case "using_base_cls": - AsyncBackendFactory.extend(backend_name, self.BACKENDS[backend_name]) + AsyncBackendFactory.push_factory_hook(hook) case _: - assert_never(method) + assert_never(action) + + # Act + backend = AsyncBackendFactory.get_backend(backend_name) # Assert - assert AsyncBackendFactory.get_all_backends(extended=True)[backend_name] is self.BACKENDS[backend_name] - assert AsyncBackendFactory.get_all_backends(extended=False)[backend_name] is self.BACKENDS[backend_name] + assert not isinstance(backend, self.BACKENDS[backend_name]) + assert isinstance(backend, MockBackend) - @pytest.mark.parametrize("backend_name", list(BACKENDS)) - def test____extend____error_invalid_class(self, backend_name: str) -> None: + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) + @pytest.mark.parametrize("invalid_cls", [int, Socket, TestAsyncBackend]) + def test____get_backend____error_do_not_derive_from_AsyncBackend( + self, + action: Literal["push_backend_factory", "push_factory_hook"], + invalid_cls: type[Any], + ) -> None: # Arrange - default_backend_cls = self.BACKENDS[backend_name] + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory("mock", invalid_cls) + case "push_factory_hook": - # Act & Assert - with pytest.raises( - TypeError, match=rf"^Invalid backend class \(not a subclass of {default_backend_cls!r}\): {MockBackend!r}$" - ): - AsyncBackendFactory.extend(backend_name, MockBackend) + def hook(name: str) -> AsyncBackend: + if name == "mock": + return invalid_cls() + raise UnsupportedOperation - @pytest.mark.parametrize("backend_name", list(BACKENDS)) - def test____extend____several_replacement(self, backend_name: str) -> None: - # Arrange - class ExtendedBackendV1(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass - - class ExtendedBackendV2(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass + AsyncBackendFactory.push_factory_hook(hook) + case _: + assert_never(action) # Act & Assert - AsyncBackendFactory.extend(backend_name, ExtendedBackendV1) - assert AsyncBackendFactory.get_all_backends(extended=True)[backend_name] is ExtendedBackendV1 - AsyncBackendFactory.extend(backend_name, ExtendedBackendV2) - assert AsyncBackendFactory.get_all_backends(extended=True)[backend_name] is ExtendedBackendV2 - - def test____get_default_backend____without_sniffio____returns_asyncio_backend(self) -> None: - # Arrange - AsyncBackendFactory.set_default_backend(None) - - # Act - backend_cls = AsyncBackendFactory.get_default_backend(guess_current_async_library=False) - - # Assert - assert backend_cls is FakeAsyncIOBackend + with pytest.raises(TypeError, match=r"^.+ did not return an AsyncBackend instance$"): + AsyncBackendFactory.get_backend("mock") - @pytest.mark.feature_sniffio @pytest.mark.parametrize("running_backend_name", list(BACKENDS)) - @pytest.mark.parametrize("extended", [False, True], ids=lambda extended: f"extended=={extended}") - def test____get_default_backend____with_sniffio____returns_running_backend( + def test____current____returns_running_backend( self, running_backend_name: str, - extended: bool, mocker: MockerFixture, ) -> None: # Arrange expected_cls = self.BACKENDS[running_backend_name] - if extended: - - class ExtendedBackend(self.BACKENDS[running_backend_name]): # type: ignore[name-defined,misc] - pass - - AsyncBackendFactory.extend(running_backend_name, ExtendedBackend) - expected_cls = ExtendedBackend mock_current_async_library: MagicMock = mocker.patch( - "sniffio.current_async_library", + self.CURRENT_ASYNC_LIBRARY_FUNC, autospec=True, return_value=running_backend_name, ) - AsyncBackendFactory.set_default_backend(None) # Act - backend_cls = AsyncBackendFactory.get_default_backend(guess_current_async_library=True) + backend = AsyncBackendFactory.current() # Assert mock_current_async_library.assert_called_once_with() - assert backend_cls is expected_cls + assert isinstance(backend, expected_cls) - @pytest.mark.feature_sniffio - def test____get_default_backend____with_sniffio____running_library_does_not_have_backend_implementation( + def test____current____running_library_does_not_have_backend_implementation( self, mocker: MockerFixture, ) -> None: # Arrange mock_current_async_library: MagicMock = mocker.patch( - "sniffio.current_async_library", + self.CURRENT_ASYNC_LIBRARY_FUNC, autospec=True, return_value="some_other_async_runner", ) - AsyncBackendFactory.set_default_backend(None) # Act & Assert - with pytest.raises(KeyError): - AsyncBackendFactory.get_default_backend(guess_current_async_library=True) + with pytest.raises( + NotImplementedError, + match=r"^Running library 'some_other_async_runner' misses the backend implementation$", + ): + AsyncBackendFactory.current() # Assert mock_current_async_library.assert_called_once_with() @pytest.mark.parametrize("backend_name", list(BACKENDS)) - @pytest.mark.parametrize("extended", [False, True], ids=lambda extended: f"extended=={extended}") - def test____new____instanciate_backend(self, backend_name: str, extended: bool) -> None: + def test____current____by_default(self, backend_name: str, mocker: MockerFixture) -> None: # Arrange - expected_cls = self.BACKENDS[backend_name] - if extended: + AsyncBackendFactory.remove_installed_hooks() + mocker.patch( + self.CURRENT_ASYNC_LIBRARY_FUNC, + autospec=True, + return_value=backend_name, + ) - class ExtendedBackend(self.BACKENDS[backend_name]): # type: ignore[name-defined,misc] - pass + # Act & Assert + match backend_name: + case "asyncio": + from easynetwork.lowlevel.std_asyncio import AsyncIOBackend + + assert type(AsyncBackendFactory.current()) is AsyncIOBackend + case _: + with pytest.raises(NotImplementedError, match=r"^Running library '.+' misses the backend implementation$"): + AsyncBackendFactory.current() - AsyncBackendFactory.extend(backend_name, ExtendedBackend) - expected_cls = ExtendedBackend + @pytest.mark.parametrize("backend_name", list(BACKENDS)) + def test____current____singleton(self, backend_name: str, mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + self.CURRENT_ASYNC_LIBRARY_FUNC, + autospec=True, + return_value=backend_name, + ) + first_backend_instance = AsyncBackendFactory.current() # Act - backend = AsyncBackendFactory.new(backend_name) + this_backend = AsyncBackendFactory.current() # Assert - assert isinstance(backend, expected_cls) + assert this_backend is first_backend_instance - def test____new____instanciate_default_backend(self, mocker: MockerFixture) -> None: + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) + def test____current____add_hook( + self, + action: Literal["push_backend_factory", "push_factory_hook"], + mocker: MockerFixture, + ) -> None: # Arrange - AsyncBackendFactory.set_default_backend(MockBackend) + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory("mock", lambda: MockBackend(mocker)) + case "push_factory_hook": + + def hook(name: str) -> AsyncBackend: + if name == "mock": + return MockBackend(mocker) + raise UnsupportedOperation + + AsyncBackendFactory.push_factory_hook(hook) + case _: + assert_never(action) + + mocker.patch( + self.CURRENT_ASYNC_LIBRARY_FUNC, + autospec=True, + return_value="mock", + ) # Act - backend = AsyncBackendFactory.new(mocker=mocker) + backend = AsyncBackendFactory.current() # Assert - assert isinstance(backend, MockBackend) + assert type(backend) is MockBackend - def test____ensure____return_given_backend_object(self, mocker: MockerFixture) -> None: + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) + @pytest.mark.parametrize("invalid_cls", [int, Socket, TestAsyncBackend]) + def test____current____error_do_not_derive_from_AsyncBackend( + self, + action: Literal["push_backend_factory", "push_factory_hook"], + invalid_cls: type[Any], + mocker: MockerFixture, + ) -> None: # Arrange - expected_backend = MockBackend(mocker) + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory("mock", invalid_cls) + case "push_factory_hook": - # Act - backend = AsyncBackendFactory.ensure(expected_backend) + def hook(name: str) -> AsyncBackend: + if name == "mock": + return invalid_cls() + raise UnsupportedOperation - # Assert - assert backend is expected_backend + AsyncBackendFactory.push_factory_hook(hook) + case _: + assert_never(action) + + mocker.patch( + self.CURRENT_ASYNC_LIBRARY_FUNC, + autospec=True, + return_value="mock", + ) + + # Act & Assert + with pytest.raises(TypeError, match=r"^.+ did not return an AsyncBackend instance$"): + AsyncBackendFactory.current() + @pytest.mark.parametrize("action", ["push_backend_factory", "push_factory_hook"]) @pytest.mark.parametrize("backend_name", list(BACKENDS)) - @pytest.mark.parametrize("backend_kwargs", [None, {}], ids=repr) - def test____ensure____instanciate_backend_from_name(self, backend_name: str, backend_kwargs: dict[str, Any] | None) -> None: + def test____current____override_backend( + self, + backend_name: str, + action: Literal["push_backend_factory", "push_factory_hook"], + mocker: MockerFixture, + ) -> None: # Arrange + match action: + case "push_backend_factory": + AsyncBackendFactory.push_backend_factory(backend_name, lambda: MockBackend(mocker)) + case "push_factory_hook": - # Act - backend = AsyncBackendFactory.ensure(backend_name, backend_kwargs) + def hook(name: str) -> AsyncBackend: + if name == backend_name: + return MockBackend(mocker) + raise UnsupportedOperation - # Assert - assert isinstance(backend, self.BACKENDS[backend_name]) + AsyncBackendFactory.push_factory_hook(hook) + case _: + assert_never(action) - def test____ensure____instanciate_default_backend(self, mocker: MockerFixture) -> None: - # Arrange - AsyncBackendFactory.set_default_backend(MockBackend) + mocker.patch( + self.CURRENT_ASYNC_LIBRARY_FUNC, + autospec=True, + return_value=backend_name, + ) # Act - backend = AsyncBackendFactory.ensure(None, {"mocker": mocker}) + backend = AsyncBackendFactory.current() # Assert + assert not isinstance(backend, self.BACKENDS[backend_name]) assert isinstance(backend, MockBackend) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py index d88720d6..cff2f963 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py @@ -2,12 +2,14 @@ import concurrent.futures import contextvars +from collections.abc import Iterator from typing import TYPE_CHECKING from easynetwork.lowlevel.api_async.backend.futures import AsyncExecutor import pytest +from .....tools import temporary_backend from ...._utils import partial_eq if TYPE_CHECKING: @@ -32,8 +34,13 @@ def executor_handle_contexts(request: pytest.FixtureRequest) -> bool: @pytest.fixture @staticmethod - def executor(mock_backend: MagicMock, mock_stdlib_executor: MagicMock, executor_handle_contexts: bool) -> AsyncExecutor: - return AsyncExecutor(mock_stdlib_executor, mock_backend, handle_contexts=executor_handle_contexts) + def executor( + mock_backend: MagicMock, + mock_stdlib_executor: MagicMock, + executor_handle_contexts: bool, + ) -> Iterator[AsyncExecutor]: + with temporary_backend(mock_backend): + yield AsyncExecutor(mock_stdlib_executor, handle_contexts=executor_handle_contexts) @pytest.fixture(autouse=True) @staticmethod @@ -42,7 +49,6 @@ def mock_contextvars_copy_context(mocker: MockerFixture) -> MagicMock: async def test____dunder_init____invalid_executor( self, - mock_backend: MagicMock, mocker: MockerFixture, ) -> None: # Arrange @@ -50,7 +56,7 @@ async def test____dunder_init____invalid_executor( # Act & Assert with pytest.raises(TypeError): - _ = AsyncExecutor(invalid_executor, mock_backend) + _ = AsyncExecutor(invalid_executor) async def test____run____submit_to_executor_and_wait( self, diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py index f7b33841..6a52e683 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py @@ -3,6 +3,7 @@ import asyncio import contextlib import operator +from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any from easynetwork.exceptions import BusyResourceError @@ -10,6 +11,9 @@ from easynetwork.lowlevel.api_async.transports.abc import AsyncDatagramListener import pytest +import pytest_asyncio + +from .....tools import temporary_backend if TYPE_CHECKING: from unittest.mock import MagicMock @@ -44,14 +48,15 @@ def make_datagram_side_effect(packet: Any) -> bytes: # mock_datagram_protocol.build_packet_from_datagram.side_effect = build_packet_from_datagram_side_effect return mock_datagram_protocol - @pytest.fixture + @pytest_asyncio.fixture @staticmethod - def server( + async def server( mock_datagram_listener: MagicMock, mock_datagram_protocol: MagicMock, mock_backend: MagicMock, - ) -> AsyncDatagramServer[Any, Any, Any]: - return AsyncDatagramServer(mock_datagram_listener, mock_datagram_protocol, backend=mock_backend) + ) -> AsyncIterator[AsyncDatagramServer[Any, Any, Any]]: + with temporary_backend(mock_backend): + yield AsyncDatagramServer(mock_datagram_listener, mock_datagram_protocol) async def test____dunder_init____invalid_transport( self, @@ -138,21 +143,11 @@ async def test____send_packet_to____send_bytes_to_transport( mock_datagram_protocol.make_datagram.assert_called_once_with(mocker.sentinel.packet) mock_datagram_listener.send_to.assert_awaited_once_with(b"packet", mocker.sentinel.destination) - async def test____get_backend____default( - self, - server: AsyncDatagramServer[Any, Any, Any], - mock_backend: MagicMock, - ) -> None: - # Arrange - - # Act & Assert - assert server.get_backend() is mock_backend - class TestClientManager: @pytest.fixture @staticmethod - def mock_backend(mock_backend: MagicMock, mocker: MockerFixture) -> MagicMock: + def mock_backend(mock_backend: MagicMock) -> MagicMock: mock_backend.create_condition_var.side_effect = asyncio.Condition return mock_backend diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py index f035087c..c38f01c1 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Generator +from collections.abc import Awaitable, Callable, Generator, Iterator from typing import TYPE_CHECKING, Any from easynetwork.lowlevel._stream import StreamDataProducer @@ -10,6 +10,8 @@ import pytest +from .....tools import temporary_backend + if TYPE_CHECKING: from unittest.mock import MagicMock @@ -138,8 +140,9 @@ def server( mock_stream_protocol: MagicMock, max_recv_size: int, mock_backend: MagicMock, - ) -> AsyncStreamServer[Any, Any]: - return AsyncStreamServer(mock_listener, mock_stream_protocol, max_recv_size, backend=mock_backend) + ) -> Iterator[AsyncStreamServer[Any, Any]]: + with temporary_backend(mock_backend): + yield AsyncStreamServer(mock_listener, mock_stream_protocol, max_recv_size) async def test____dunder_init____invalid_transport( self, @@ -243,13 +246,3 @@ async def serve_side_effect(handler: Callable[[Any], Awaitable[None]], task_grou with pytest.raises(TypeError, match=r"^Expected an AsyncStreamTransport object, got .*$"): await server.serve(client_connected_cb, tg) client_connected_cb.assert_not_called() - - async def test____get_backend____default( - self, - server: AsyncStreamServer[Any, Any], - mock_backend: MagicMock, - ) -> None: - # Arrange - - # Act & Assert - assert server.get_backend() is mock_backend diff --git a/tox.ini b/tox.ini index e010a0b6..d3bd4b47 100644 --- a/tox.ini +++ b/tox.ini @@ -8,11 +8,11 @@ envlist = build # Tests (3.11) py311-other-{tests,docstrings} - py311-{unit,functional}-{__standard__,cbor,msgpack,encryption,sniffio} - py311-functional-{asyncio_proactor,uvloop} + py311-{unit,functional}-{__standard__,cbor,msgpack,encryption} + py311-functional-{sniffio,asyncio_proactor,uvloop} # Tests (3.12) - py312-{unit,functional}-{__standard__,cbor,msgpack,encryption,sniffio} - py312-functional-{asyncio_proactor,uvloop} + py312-{unit,functional}-{__standard__,cbor,msgpack,encryption} + py312-functional-{sniffio,asyncio_proactor,uvloop} # Report coverage skip_missing_interpreters = true @@ -51,14 +51,13 @@ commands = docstrings: pytest --doctest-modules {posargs} {[docs]examples_dir}{/}tutorials{/}ftp_server docstrings: pytest --doctest-glob="*.rst" {posargs} {[docs]source_dir} -[testenv:{py311,py312}-{unit,functional}-{__standard__,cbor,msgpack,encryption,sniffio}] +[testenv:{py311,py312}-{unit,functional}-{__standard__,cbor,msgpack,encryption}] package = wheel groups = test cbor: cbor msgpack: msgpack encryption: encryption - sniffio: sniffio setenv = {[base]setenv} {[base-pytest]setenv} @@ -73,7 +72,20 @@ commands = cbor: pytest -m "feature_cbor" {posargs} {env:TESTS_ROOTDIR} msgpack: pytest -m "feature_msgpack" {posargs} {env:TESTS_ROOTDIR} encryption: pytest -m "feature_encryption" {posargs} {env:TESTS_ROOTDIR} - sniffio: pytest -m "feature_sniffio" {posargs} {env:TESTS_ROOTDIR} + +[testenv:{py311,py312}-functional-sniffio] +package = wheel +groups = + test + sniffio +setenv = + {[base]setenv} + {[base-pytest]setenv} + PYTEST_ADDOPTS = {[base-pytest]addopts} --cov --cov-report='' + COVERAGE_FILE = .coverage.{envname} + TESTS_ROOTDIR = {[base-pytest]functional_tests_rootdir} +commands = + pytest -m "feature_sniffio" {posargs} {env:TESTS_ROOTDIR} [testenv:{py311,py312}-functional-{asyncio_proactor,uvloop}] package = wheel @@ -99,8 +111,8 @@ commands = [testenv:coverage] skip_install = True depends = - {py311,py312}-{unit,functional}-{__standard__,cbor,msgpack,encryption,sniffio} - {py311,py312}-functional-{asyncio_proactor,uvloop} + {py311,py312}-{unit,functional}-{__standard__,cbor,msgpack,encryption} + {py311,py312}-functional-{sniffio,asyncio_proactor,uvloop} parallel_show_output = True groups = coverage