diff --git a/src/easynetwork_asyncio/threads.py b/src/easynetwork_asyncio/threads.py index 08f96e00..90ebcf05 100644 --- a/src/easynetwork_asyncio/threads.py +++ b/src/easynetwork_asyncio/threads.py @@ -85,7 +85,21 @@ def schedule_task() -> concurrent.futures.Future[_T]: future: concurrent.futures.Future[_T] = concurrent.futures.Future() async def coroutine() -> None: + def on_fut_done(future: concurrent.futures.Future[_T]) -> None: + if future.cancelled(): + try: + self.run_sync(task.cancel) + except RuntimeError: + # on_fut_done() called from coroutine() + # or the portal is already shut down + pass + + task = TaskUtils.current_asyncio_task() try: + if future.cancelled(): + task.cancel() + else: + future.add_done_callback(on_fut_done) result = await coro_func(*args, **kwargs) except asyncio.CancelledError: future.cancel() @@ -102,20 +116,9 @@ async def coroutine() -> None: task = self.__task_group.create_task(coroutine()) loop = task.get_loop() + del task with self.__lock.get(): loop.call_soon(self.__register_waiter(self.__call_soon_waiters, loop).set_result, None) - - def on_fut_done(future: concurrent.futures.Future[_T]) -> None: - if future.cancelled(): - try: - self.run_sync(task.cancel) - except RuntimeError: - # on_fut_done() called from coroutine() - # or the portal is already shut down - pass - - future.add_done_callback(on_fut_done) - return future return self.run_sync(schedule_task) diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index 78bee52b..ec837ab3 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -973,6 +973,43 @@ def thread() -> None: cancellation_ignored.assert_called_once() + async def test____create_threads_portal____run_coroutine_soon____future_cancelled_before_await( + self, + event_loop: asyncio.AbstractEventLoop, + backend: AsyncIOBackend, + ) -> None: + checkpoints: list[str] = [] + + async def coroutine() -> None: + current_task = asyncio.current_task() + assert current_task is not None + + checkpoints.append(f"{current_task.cancelling()=}") + await asyncio.sleep(0) + checkpoints.append("does-not-raise-CancelledError") + + def thread() -> None: + future = threads_portal.run_coroutine_soon(coroutine) + future.cancel() + + wait_concurrent_futures({future}, timeout=5) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + event_loop_slowdown_handle: asyncio.Handle + + def event_loop_slowdown() -> None: # Drastically slow down event loop + nonlocal event_loop_slowdown_handle + + time.sleep(0.5) + event_loop_slowdown_handle = event_loop.call_soon(event_loop_slowdown) + + event_loop_slowdown_handle = event_loop.call_soon(event_loop_slowdown) + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + event_loop_slowdown_handle.cancel() + assert checkpoints == ["current_task.cancelling()=1"] + async def test____create_threads_portal____context_exit____wait_scheduled_call_soon( self, event_loop: asyncio.AbstractEventLoop,