diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py index e24a8a2a..92ad3ff3 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py @@ -265,7 +265,13 @@ def create_condition_var(self, lock: ILock | None = None) -> ICondition: return self.__asyncio.Condition(lock) - async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + async def run_in_thread( + self, + func: Callable[[*_T_PosArgs], _T], + /, + *args: *_T_PosArgs, + abandon_on_cancel: bool = False, + ) -> _T: import sniffio loop = self.__asyncio.get_running_loop() @@ -273,8 +279,11 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg ctx.run(sniffio.current_async_library_cvar.set, None) - cb = functools.partial(ctx.run, func, *args, **kwargs) - return await self.__cancel_shielded_await(loop.run_in_executor(None, cb)) + cb = functools.partial(ctx.run, func, *args) + if abandon_on_cancel: + return await loop.run_in_executor(None, cb) + else: + return await self.__cancel_shielded_await(loop.run_in_executor(None, cb)) def create_threads_portal(self) -> ThreadsPortal: from .threads import ThreadsPortal diff --git a/src/easynetwork/lowlevel/api_async/backend/abc.py b/src/easynetwork/lowlevel/api_async/backend/abc.py index 1978288b..a22b7a52 100644 --- a/src/easynetwork/lowlevel/api_async/backend/abc.py +++ b/src/easynetwork/lowlevel/api_async/backend/abc.py @@ -1085,7 +1085,13 @@ def create_condition_var(self, lock: ILock | None = ...) -> ICondition: raise NotImplementedError @abstractmethod - async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> _T: + async def run_in_thread( + self, + func: Callable[[*_T_PosArgs], _T], + /, + *args: *_T_PosArgs, + abandon_on_cancel: bool = ..., + ) -> _T: """ Executes a synchronous function in a worker thread. @@ -1096,7 +1102,11 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg Cancellation handling: Because there is no way to "cancel" an arbitrary function call in an OS thread, - once the job is started, any cancellation requests will be discarded. + once the job is started: + + * If `abandon_on_cancel` is False (the default), any cancellation requests will be discarded. + + * If `abandon_on_cancel` is True, the task will notify the thread to stop (if possible) then will bail out. Warning: Due to the current coroutine implementation, `func` should not raise a :exc:`StopIteration`. @@ -1104,14 +1114,15 @@ async def run_in_thread(self, func: Callable[_P, _T], /, *args: _P.args, **kwarg Parameters: func: A synchronous function. - args: Positional arguments to be passed to `func`. - kwargs: Keyword arguments to be passed to `func`. + args: Positional arguments to be passed to `func`. If you need to pass keyword arguments, + then use :func:`functools.partial`. + abandon_on_cancel: Whether or not to abort task on cancellation request. Raises: - Exception: Whatever ``func(*args, **kwargs)`` raises. + Exception: Whatever ``func(*args)`` raises. Returns: - Whatever ``func(*args, **kwargs)`` returns. + Whatever ``func(*args)`` returns. """ raise NotImplementedError diff --git a/src/easynetwork/lowlevel/futures.py b/src/easynetwork/lowlevel/futures.py index 36c6db9e..7bb4407c 100644 --- a/src/easynetwork/lowlevel/futures.py +++ b/src/easynetwork/lowlevel/futures.py @@ -199,7 +199,8 @@ async def shutdown(self, *, cancel_futures: bool = False) -> None: has not started running. Any futures that are completed or running won't be cancelled, regardless of the value of `cancel_futures`. """ - await self.__backend.run_in_thread(self.__executor.shutdown, wait=True, cancel_futures=cancel_futures) + shutdown_callback = functools.partial(self.__executor.shutdown, wait=True, cancel_futures=cancel_futures) + await self.__backend.run_in_thread(shutdown_callback) def _setup_func(self, func: Callable[_P, _T]) -> Callable[_P, _T]: if self.__handle_contexts: diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index 9367af26..b7a912e3 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -723,7 +723,7 @@ async def coroutine() -> None: except* FutureException: pass - async def test____run_in_thread____cannot_be_cancelled( + async def test____run_in_thread____cannot_be_cancelled_by_default( self, backend: AsyncIOBackend, ) -> None: @@ -738,6 +738,19 @@ async def test____run_in_thread____cannot_be_cancelled( assert not task.cancelled() + async def test____run_in_thread____abandon_on_cancel( + self, + backend: AsyncIOBackend, + ) -> None: + event_loop = asyncio.get_running_loop() + task = asyncio.create_task(backend.run_in_thread(time.sleep, 0.5, abandon_on_cancel=True)) + event_loop.call_later(0.1, task.cancel) + + with pytest.raises(asyncio.CancelledError): + await task + + assert task.cancelled() + async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncIOBackend) -> None: sniffio.current_async_library_cvar.set("asyncio") diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py index d22565ee..bb5fecf5 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py @@ -980,8 +980,10 @@ async def test____create_condition_var____use_asyncio_Condition_class( mock_Condition.assert_called_once_with(mock_lock) assert condition is mocker.sentinel.condition_var + @pytest.mark.parametrize("abandon_on_cancel", [False, True], ids=lambda p: f"abandon_on_cancel=={p}") async def test____run_in_thread____use_loop_run_in_executor( self, + abandon_on_cancel: bool, backend: AsyncIOBackend, mocker: MockerFixture, ) -> None: @@ -999,8 +1001,7 @@ async def test____run_in_thread____use_loop_run_in_executor( func_stub, mocker.sentinel.arg1, mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, + abandon_on_cancel=abandon_on_cancel, ) # Assert @@ -1012,8 +1013,6 @@ async def test____run_in_thread____use_loop_run_in_executor( func_stub, mocker.sentinel.arg1, mocker.sentinel.arg2, - kw1=mocker.sentinel.kwargs1, - kw2=mocker.sentinel.kwargs2, ), ) func_stub.assert_not_called() diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/_fake_backends.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/_fake_backends.py index 05d8dfa9..ffa6e4a0 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/_fake_backends.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/_fake_backends.py @@ -79,15 +79,13 @@ def create_event(self) -> Any: def create_condition_var(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError + @no_type_check async def run_in_thread(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError def create_threads_portal(self) -> Any: raise NotImplementedError - async def wait_future(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError - @final class FakeAsyncIOBackend(BaseFakeBackend): diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_futures.py b/tests/unit_test/test_async/test_lowlevel_api/test_futures.py index 19b2c4a4..b614e4bc 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_futures.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_futures.py @@ -238,7 +238,9 @@ async def test____shutdown____shutdown_executor( # Assert mock_stdlib_executor.shutdown.assert_not_called() - mock_backend.run_in_thread.assert_awaited_once_with(mock_stdlib_executor.shutdown, wait=True, cancel_futures=False) + mock_backend.run_in_thread.assert_awaited_once_with( + partial_eq(mock_stdlib_executor.shutdown, wait=True, cancel_futures=False) + ) @pytest.mark.parametrize("cancel_futures", [False, True]) async def test____shutdown____shutdown_executor____cancel_futures( @@ -257,9 +259,7 @@ async def test____shutdown____shutdown_executor____cancel_futures( # Assert mock_stdlib_executor.shutdown.assert_not_called() mock_backend.run_in_thread.assert_awaited_once_with( - mock_stdlib_executor.shutdown, - wait=True, - cancel_futures=cancel_futures, + partial_eq(mock_stdlib_executor.shutdown, wait=True, cancel_futures=cancel_futures) ) async def test____context_manager____shutdown_executor_at_end( @@ -280,4 +280,6 @@ async def test____context_manager____shutdown_executor_at_end( # Assert mock_stdlib_executor.shutdown.assert_not_called() - mock_backend.run_in_thread.assert_awaited_once_with(mock_stdlib_executor.shutdown, wait=True, cancel_futures=False) + mock_backend.run_in_thread.assert_awaited_once_with( + partial_eq(mock_stdlib_executor.shutdown, wait=True, cancel_futures=False) + )