diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 564966ed..64de2530 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -78,19 +78,32 @@ def in_pending_state(method: F) -> F: @functools.wraps(method) async def wrapper(self, *args, **kwargs): """Create a future for the decorated method.""" - if self._attempted_start or not self._ready: - self._ready = _get_future() + # Initialize the ready_count to 0 if it doesn't exist + if self.owns_kernel: + if not hasattr(self, "_ready_count"): + self._ready_count = 0 + + if self._ready_count == 0: + self._ready = _get_future() + + self._ready_count += 1 + try: # call wrapped method, await, and set the result or exception. out = await method(self, *args, **kwargs) # Add a small sleep to ensure tests can capture the state before done await asyncio.sleep(0.01) if self.owns_kernel: - self._ready.set_result(None) + self._ready_count -= 1 + if self._ready_count == 0: + self._ready.set_result(None) return out except Exception as e: - self._ready.set_exception(e) - self.log.exception(self._ready.exception()) + if self.owns_kernel: + self._ready_count -= 1 + if self._ready_count == 0: + self._ready.set_exception(e) + self.log.exception(e) raise e return t.cast(F, wrapper) @@ -109,7 +122,6 @@ def __init__(self, *args, **kwargs): self._owns_kernel = kwargs.pop("owns_kernel", True) super().__init__(**kwargs) self._shutdown_status = _ShutdownStatus.Unset - self._attempted_start = False self._ready = None _created_context: Bool = Bool(False) @@ -397,7 +409,6 @@ async def _async_start_kernel(self, **kw: t.Any) -> None: keyword arguments that are passed down to build the kernel_cmd and launching the kernel (e.g. Popen kwargs). """ - self._attempted_start = True kernel_cmd, kw = await self._async_pre_start_kernel(**kw) # launch the kernel subprocess diff --git a/tests/test_manager.py b/tests/test_manager.py index e3d6ea22..1e3d2a32 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,6 +1,7 @@ """Tests for KernelManager""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import asyncio import os import tempfile from unittest import mock @@ -32,3 +33,14 @@ def test_connection_file_real_path(): km._launch_args = {} cmds = km.format_kernel_cmd() assert cmds[4] == "foobar" + + +async def test_in_pending_state(): + """Verify in_pending_state race condition""" + tm = KernelManager() + start_kernel = asyncio.ensure_future(tm._async_start_kernel()) + shutdown_kernel = asyncio.ensure_future(tm._async_shutdown_kernel()) + + await start_kernel + await shutdown_kernel + assert tm.is_alive() is False