diff --git a/src/easynetwork/lowlevel/api_async/backend/_sniffio_helpers.py b/src/easynetwork/lowlevel/api_async/backend/_sniffio_helpers.py new file mode 100644 index 00000000..4bd8b5e1 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_sniffio_helpers.py @@ -0,0 +1,52 @@ +# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Helper module for sniffio integration""" + +from __future__ import annotations + +__all__ = ["current_async_library", "setup_sniffio_contextvar"] + +import contextvars +import sys + + +def current_async_library() -> str: + try: + from sniffio import current_async_library + except ModuleNotFoundError: + current_async_library = _current_async_library_fallback + return current_async_library() + + +def _current_async_library_fallback() -> str: + if "asyncio" in sys.modules: + import asyncio + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + return "asyncio" + raise RuntimeError("unknown async library, or not in async context") + + +def setup_sniffio_contextvar(context: contextvars.Context, library_name: str | None, /) -> None: + try: + import sniffio + except ModuleNotFoundError: + pass + else: + context.run(sniffio.current_async_library_cvar.set, library_name) diff --git a/src/easynetwork/lowlevel/api_async/backend/factory.py b/src/easynetwork/lowlevel/api_async/backend/factory.py index cead771a..0b3a11e1 100644 --- a/src/easynetwork/lowlevel/api_async/backend/factory.py +++ b/src/easynetwork/lowlevel/api_async/backend/factory.py @@ -25,8 +25,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Final, final +from ._sniffio_helpers import current_async_library as _sniffio_current_async_library from .abc import AsyncBackend -from .sniffio import current_async_library as _sniffio_current_async_library if TYPE_CHECKING: from importlib.metadata import EntryPoint diff --git a/src/easynetwork/lowlevel/api_async/backend/futures.py b/src/easynetwork/lowlevel/api_async/backend/futures.py index d47c359e..8b76831e 100644 --- a/src/easynetwork/lowlevel/api_async/backend/futures.py +++ b/src/easynetwork/lowlevel/api_async/backend/futures.py @@ -25,8 +25,8 @@ from collections.abc import AsyncGenerator, Callable, Iterable, Mapping from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar +from . import _sniffio_helpers from .factory import AsyncBackendFactory -from .sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar if TYPE_CHECKING: from types import TracebackType @@ -206,10 +206,7 @@ async def shutdown(self, *, cancel_futures: bool = False) -> None: def _setup_func(self, func: Callable[_P, _T]) -> Callable[_P, _T]: if self.__handle_contexts: ctx = contextvars.copy_context() - - if _sniffio_current_async_library_cvar is not None: - ctx.run(_sniffio_current_async_library_cvar.set, None) - + _sniffio_helpers.setup_sniffio_contextvar(ctx, None) func = functools.partial(ctx.run, func) return func diff --git a/src/easynetwork/lowlevel/api_async/backend/sniffio.py b/src/easynetwork/lowlevel/api_async/backend/sniffio.py deleted file mode 100644 index a8b1008b..00000000 --- a/src/easynetwork/lowlevel/api_async/backend/sniffio.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -"""Helper module for sniffio integration""" - -from __future__ import annotations - -__all__ = ["current_async_library", "current_async_library_cvar"] - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from contextvars import ContextVar - - -current_async_library_cvar: ContextVar[str | None] | None - -try: - import sniffio - - def current_async_library() -> str: - return sniffio.current_async_library() - -except ImportError: - - def current_async_library() -> str: - return "asyncio" - - -try: - from sniffio import current_async_library_cvar -except ImportError: - current_async_library_cvar = None diff --git a/src/easynetwork/lowlevel/asyncio/backend.py b/src/easynetwork/lowlevel/asyncio/backend.py index 294367cb..ee64fad2 100644 --- a/src/easynetwork/lowlevel/asyncio/backend.py +++ b/src/easynetwork/lowlevel/asyncio/backend.py @@ -40,8 +40,8 @@ del _ssl from ...exceptions import UnsupportedOperation +from ..api_async.backend import _sniffio_helpers from ..api_async.backend.abc import AsyncBackend as AbstractAsyncBackend -from ..api_async.backend.sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar from ._asyncio_utils import ( create_connection, create_datagram_connection, @@ -391,8 +391,7 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg loop = asyncio.get_running_loop() ctx = contextvars.copy_context() - if _sniffio_current_async_library_cvar is not None: - ctx.run(_sniffio_current_async_library_cvar.set, None) + _sniffio_helpers.setup_sniffio_contextvar(ctx, None) future = loop.run_in_executor(None, functools.partial(ctx.run, func, *args, **kwargs)) try: diff --git a/src/easynetwork/lowlevel/asyncio/threads.py b/src/easynetwork/lowlevel/asyncio/threads.py index f863c0ef..302ea44d 100644 --- a/src/easynetwork/lowlevel/asyncio/threads.py +++ b/src/easynetwork/lowlevel/asyncio/threads.py @@ -28,8 +28,8 @@ from typing import TYPE_CHECKING, ParamSpec, Self, TypeVar, final from .. import _lock, _utils +from ..api_async.backend import _sniffio_helpers from ..api_async.backend.abc import ThreadsPortal as AbstractThreadsPortal -from ..api_async.backend.sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar from .tasks import TaskUtils if TYPE_CHECKING: @@ -146,9 +146,7 @@ def callback() -> None: waiter = self.__register_waiter(self.__call_soon_waiters, loop) ctx = contextvars.copy_context() - - if _sniffio_current_async_library_cvar is not None: - ctx.run(_sniffio_current_async_library_cvar.set, "asyncio") + _sniffio_helpers.setup_sniffio_contextvar(ctx, "asyncio") future: concurrent.futures.Future[_T] = concurrent.futures.Future() diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py index bc24cad3..d88720d6 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_futures.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from easynetwork.lowlevel.api_async.backend.futures import AsyncExecutor -from easynetwork.lowlevel.api_async.backend.sniffio import current_async_library_cvar import pytest @@ -85,10 +84,6 @@ async def test____run____submit_to_executor_and_wait( # Assert if executor_handle_contexts: mock_contextvars_copy_context.assert_called_once_with() - if current_async_library_cvar is not None: - mock_context.run.assert_called_once_with(current_async_library_cvar.set, None) - else: - mock_context.run.assert_not_called() mock_stdlib_executor.submit.assert_called_once_with( partial_eq(mock_context.run, func), mocker.sentinel.arg1, @@ -98,7 +93,6 @@ async def test____run____submit_to_executor_and_wait( ) else: mock_contextvars_copy_context.assert_not_called() - mock_context.run.assert_not_called() mock_stdlib_executor.submit.assert_called_once_with( func, mocker.sentinel.arg1, @@ -152,12 +146,6 @@ async def test____map____submit_to_executor_and_wait( # Assert if executor_handle_contexts: assert mock_contextvars_copy_context.call_args_list == [mocker.call() for _ in range(len(mock_contexts))] - if current_async_library_cvar is not None: - for mock_context in mock_contexts: - mock_context.run.assert_called_once_with(current_async_library_cvar.set, None) - else: - for mock_context in mock_contexts: - mock_context.run.assert_not_called() assert mock_stdlib_executor.submit.call_args_list == [ mocker.call(partial_eq(mock_context.run, func), arg) for mock_context, arg in zip(mock_contexts, func_args) ] diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_sniffio_helpers.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_sniffio_helpers.py new file mode 100644 index 00000000..53d73eb8 --- /dev/null +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_sniffio_helpers.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import sys + +from easynetwork.lowlevel.api_async.backend._sniffio_helpers import _current_async_library_fallback + +import pytest + + +def test____current_async_library_fallback____asyncio_not_imported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Arrange + monkeypatch.delitem(sys.modules, "asyncio", raising=False) + + # Act + with pytest.raises(RuntimeError, match=r"^unknown async library, or not in async context$"): + _current_async_library_fallback() + + # Assert + assert "asyncio" not in sys.modules + + +def test____current_async_library_fallback____asyncio_not_running() -> None: + # Arrange + import asyncio + + del asyncio + + # Act & Assert + with pytest.raises(RuntimeError, match=r"^unknown async library, or not in async context$"): + _current_async_library_fallback() + + +def test____current_async_library_fallback____asyncio_is_running() -> None: + # Arrange + import asyncio + + async def main() -> str: + return _current_async_library_fallback() + + # Act + library_name = asyncio.run(main()) + + # Assert + assert library_name == "asyncio"