diff --git a/jupyter_ui_poll/_poll.py b/jupyter_ui_poll/_poll.py index 81fb11e..5fb431d 100644 --- a/jupyter_ui_poll/_poll.py +++ b/jupyter_ui_poll/_poll.py @@ -3,7 +3,7 @@ import time from collections import abc from functools import singledispatch -from inspect import iscoroutinefunction +from inspect import iscoroutinefunction, isawaitable from typing import ( Any, AsyncIterable, @@ -46,9 +46,8 @@ def __init__(self, shell, loop) -> None: self._events: List[Tuple[Any, Any, Any]] = [] self._backup_execute_request = kernel.shell_handlers["execute_request"] self._aproc = None - self._kernel_is_async = iscoroutinefunction(self._backup_execute_request) - if self._kernel_is_async: # ipykernel 6+ + if iscoroutinefunction(self._backup_execute_request): # ipykernel 6+ kernel.shell_handlers["execute_request"] = self._execute_request_async else: # ipykernel < 6 @@ -90,7 +89,7 @@ async def replay(self): kernel._send_abort_reply(stream, parent, ident) else: rr = kernel.execute_request(stream, ident, parent) - if self._kernel_is_async: + if isawaitable(rr): await rr # replicate shell_dispatch behaviour