From 6ffd5d76149c90b8ed6ad2678bbb840219748f50 Mon Sep 17 00:00:00 2001 From: PabloLec Date: Fri, 27 Sep 2024 18:24:19 +0200 Subject: [PATCH] Fix integration test fictures --- tests/integration/test_full_workflow.py | 43 +++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_full_workflow.py b/tests/integration/test_full_workflow.py index 1a0a4d6..d197582 100644 --- a/tests/integration/test_full_workflow.py +++ b/tests/integration/test_full_workflow.py @@ -1,3 +1,7 @@ +import contextvars +import functools +import traceback +from asyncio import DefaultEventLoopPolicy, Task from pathlib import Path from unittest.mock import MagicMock @@ -23,10 +27,43 @@ ) -@pytest.mark.asyncio(scope="class") +def task_factory(loop, coro, context=None): + stack = traceback.extract_stack() + for frame in stack[-2::-1]: + package_name = Path(frame.filename).parts[-2] + if package_name != "asyncio": + if package_name == "pytest_asyncio": + # This function was called from pytest_asyncio, use shared context + break + else: + # This function was called from somewhere else, create context copy + context = None + break + return Task(coro, loop=loop, context=context) + + +class CustomEventLoopPolicy(DefaultEventLoopPolicy): + def __init__(self, context) -> None: + super().__init__() + self.context = context + + def new_event_loop(self): + loop = self._loop_factory() + loop.set_task_factory(functools.partial(task_factory, context=self.context)) + return loop + + +@pytest.fixture(scope="session") +def event_loop_policy(): + policy = CustomEventLoopPolicy(contextvars.copy_context()) + yield policy + policy.get_event_loop().close() + + +@pytest.mark.asyncio(loop_scope="class") @pytest.mark.incremental class TestFullWorkflow: - @pytest_asyncio.fixture(scope="class") + @pytest_asyncio.fixture(scope="class", loop_scope="class") def session_patch(self, session_mocker): session_mocker.patch( "recoverpy.lib.env_check._is_user_root", @@ -41,7 +78,7 @@ def session_patch(self, session_mocker): MagicMock(return_value=True), ) - @pytest_asyncio.fixture(scope="class") + @pytest_asyncio.fixture(scope="class", loop_scope="class") async def pilot(self, session_patch): async with RecoverpyApp().run_test() as p: yield p