From bd10a20fb0b9c16204313e01be7b2116e50b0858 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Sun, 15 Oct 2023 21:09:39 -0500 Subject: [PATCH 1/3] Use strict typing --- docs/conf.py | 2 +- jupyter_client/asynchronous/client.py | 5 +- jupyter_client/blocking/client.py | 6 +- jupyter_client/channels.py | 2 +- jupyter_client/channelsabc.py | 14 ++--- jupyter_client/client.py | 4 +- jupyter_client/clientabc.py | 39 ++++++++----- jupyter_client/connect.py | 15 +++-- jupyter_client/consoleapp.py | 2 +- jupyter_client/ioloop/manager.py | 16 +++--- jupyter_client/ioloop/restarter.py | 9 +-- jupyter_client/jsonutil.py | 12 ++-- jupyter_client/kernelapp.py | 5 +- jupyter_client/kernelspec.py | 76 ++++++++++++++++---------- jupyter_client/kernelspecapp.py | 37 +++++++------ jupyter_client/localinterfaces.py | 52 +++++++++--------- jupyter_client/manager.py | 6 +- jupyter_client/managerabc.py | 19 ++++--- jupyter_client/multikernelmanager.py | 50 ++++++++--------- jupyter_client/provisioning/factory.py | 4 +- jupyter_client/restarter.py | 19 ++++--- jupyter_client/runapp.py | 13 +++-- jupyter_client/session.py | 32 +++++------ jupyter_client/ssh/forward.py | 2 +- jupyter_client/ssh/tunnel.py | 73 +++++++++++++++++++------ jupyter_client/threaded.py | 19 ++++--- jupyter_client/utils.py | 7 ++- jupyter_client/win_interrupt.py | 12 ++-- pyproject.toml | 21 ++----- tests/test_kernelspecapp.py | 4 +- tests/test_localinterfaces.py | 2 +- 31 files changed, 330 insertions(+), 249 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1054707c0..bee5430cd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -329,7 +329,7 @@ def filter(self, record: pylogging.LogRecord) -> bool: intersphinx_mapping = {'ipython': ('http://ipython.readthedocs.io/en/stable/', None)} -def setup(app): +def setup(app: object) -> None: HERE = osp.abspath(osp.dirname(__file__)) dest = osp.join(HERE, 'changelog.md') shutil.copy(osp.join(HERE, '..', 'CHANGELOG.md'), dest) diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 7e8216750..53c68ffba 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -1,6 +1,7 @@ """Implements an async kernel client""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import typing as t import zmq.asyncio from traitlets import Instance, Type @@ -9,10 +10,10 @@ from ..client import KernelClient, reqrep -def wrapped(meth, channel): +def wrapped(meth: t.Callable, channel: str) -> t.Callable: """Wrap a method on a channel and handle replies.""" - def _(self, *args, **kwargs): + def _(self: AsyncKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any: reply = kwargs.pop("reply", False) timeout = kwargs.pop("timeout", None) msg_id = meth(self, *args, **kwargs) diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 271526221..bff55f5a7 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -4,6 +4,8 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
+import typing as t + from traitlets import Type from ..channels import HBChannel, ZMQSocketChannel @@ -11,10 +13,10 @@ from ..utils import run_sync -def wrapped(meth, channel): +def wrapped(meth: t.Callable, channel: str) -> t.Callable: """Wrap a method on a channel and handle replies.""" - def _(self, *args, **kwargs): + def _(self: BlockingKernelClient, *args: t.Any, **kwargs: t.Any) -> t.Any: reply = kwargs.pop("reply", False) timeout = kwargs.pop("timeout", None) msg_id = meth(self, *args, **kwargs) diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index 7d3618614..5b2eedad8 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -54,7 +54,7 @@ def __init__( context: t.Optional[zmq.Context] = None, session: t.Optional[Session] = None, address: t.Union[t.Tuple[str, int], str] = "", - ): + ) -> None: """Create the heartbeat monitor thread. Parameters diff --git a/jupyter_client/channelsabc.py b/jupyter_client/channelsabc.py index 1bfe7922c..af053dfa3 100644 --- a/jupyter_client/channelsabc.py +++ b/jupyter_client/channelsabc.py @@ -8,17 +8,17 @@ class ChannelABC(metaclass=abc.ABCMeta): """A base class for all channel ABCs.""" @abc.abstractmethod - def start(self): + def start(self) -> None: """Start the channel.""" pass @abc.abstractmethod - def stop(self): + def stop(self) -> None: """Stop the channel.""" pass @abc.abstractmethod - def is_alive(self): + def is_alive(self) -> bool: """Test whether the channel is alive.""" pass @@ -32,20 +32,20 @@ class HBChannelABC(ChannelABC): """ @abc.abstractproperty - def time_to_dead(self): + def time_to_dead(self) -> float: pass @abc.abstractmethod - def pause(self): + def pause(self) -> None: """Pause the heartbeat channel.""" pass @abc.abstractmethod - def unpause(self): + def unpause(self) -> None: """Unpause the heartbeat channel.""" pass @abc.abstractmethod - def is_beating(self): + def is_beating(self) -> bool: """Test whether the channel is beating.""" pass diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 91adab679..aa353ac28 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -113,7 +113,7 @@ def _context_default(self) -> zmq.Context: # flag for whether execute requests should be allowed to call raw_input: allow_stdin: bool = True - def __del__(self): + def __del__(self) -> None: """Handle garbage collection. 
Destroy context if applicable.""" if ( self._created_context @@ -511,7 +511,7 @@ async def _async_execute_interactive( if output_hook is None and "IPython" in sys.modules: from IPython import get_ipython - ip = get_ipython() + ip = get_ipython() # type:ignore[no-untyped-call] in_kernel = getattr(ip, "kernel", False) if in_kernel: output_hook = partial( diff --git a/jupyter_client/clientabc.py b/jupyter_client/clientabc.py index 3623b833f..cc14bda67 100644 --- a/jupyter_client/clientabc.py +++ b/jupyter_client/clientabc.py @@ -9,6 +9,10 @@ # Imports # ----------------------------------------------------------------------------- import abc +from typing import TYPE_CHECKING, Any, Type + +if TYPE_CHECKING: + from .channelsabc import ChannelABC # ----------------------------------------------------------------------------- # Main kernel client class @@ -24,27 +28,27 @@ class KernelClientABC(metaclass=abc.ABCMeta): """ @abc.abstractproperty - def kernel(self): + def kernel(self) -> Any: pass @abc.abstractproperty - def shell_channel_class(self): + def shell_channel_class(self) -> Type[ChannelABC]: pass @abc.abstractproperty - def iopub_channel_class(self): + def iopub_channel_class(self) -> Type[ChannelABC]: pass @abc.abstractproperty - def hb_channel_class(self): + def hb_channel_class(self) -> Type[ChannelABC]: pass @abc.abstractproperty - def stdin_channel_class(self): + def stdin_channel_class(self) -> Type[ChannelABC]: pass @abc.abstractproperty - def control_channel_class(self): + def control_channel_class(self) -> Type[ChannelABC]: pass # -------------------------------------------------------------------------- @@ -52,36 +56,43 @@ def control_channel_class(self): # -------------------------------------------------------------------------- @abc.abstractmethod - def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): + def start_channels( + self, + shell: bool = True, + iopub: bool = True, + stdin: bool = True, + hb: bool = True, + control: bool = True, + ) -> None: """Start the channels for the client.""" pass @abc.abstractmethod - def stop_channels(self): + def stop_channels(self) -> None: """Stop the channels for the client.""" pass @abc.abstractproperty - def channels_running(self): + def channels_running(self) -> bool: """Get whether the channels are running.""" pass @abc.abstractproperty - def shell_channel(self): + def shell_channel(self) -> ChannelABC: pass @abc.abstractproperty - def iopub_channel(self): + def iopub_channel(self) -> ChannelABC: pass @abc.abstractproperty - def stdin_channel(self): + def stdin_channel(self) -> ChannelABC: pass @abc.abstractproperty - def hb_channel(self): + def hb_channel(self) -> ChannelABC: pass @abc.abstractproperty - def control_channel(self): + def control_channel(self) -> ChannelABC: pass diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index 74e467371..2564b7693 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -14,7 +14,7 @@ import tempfile import warnings from getpass import getpass -from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast import zmq from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write @@ -24,6 +24,11 @@ from .localinterfaces import localhost from .utils import _filefind +if TYPE_CHECKING: + from jupyter_client import BlockingKernelClient + + from .session import Session + # Define custom type for kernel connection info 
KernelConnectionInfo = Dict[str, Union[int, str, bytes]] @@ -312,7 +317,7 @@ class ConnectionFileMixin(LoggingConfigurable): data_dir: Union[str, Unicode] = Unicode() - def _data_dir_default(self): + def _data_dir_default(self) -> str: return jupyter_data_dir() # The addresses for the communication channels @@ -351,7 +356,7 @@ def _ip_default(self) -> str: return localhost() @observe("ip") - def _ip_changed(self, change): + def _ip_changed(self, change: Any) -> None: if change["new"] == "*": self.ip = "0.0.0.0" # noqa @@ -373,7 +378,7 @@ def ports(self) -> List[int]: # The Session to use for communication with the kernel. session = Instance("jupyter_client.session.Session") - def _session_default(self): + def _session_default(self) -> Session: from .session import Session return Session(parent=self) @@ -423,7 +428,7 @@ def get_connection_info(self, session: bool = False) -> KernelConnectionInfo: # factory for blocking clients blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient") - def blocking_client(self): + def blocking_client(self) -> BlockingKernelClient: """Make a blocking client connected to my kernel""" info = self.get_connection_info() bc = self.blocking_class(parent=self) # type:ignore[operator] diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index f49a25b54..e96daecb5 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -370,7 +370,7 @@ def initialize(self, argv: object = None) -> None: class IPythonConsoleApp(JupyterConsoleApp): """An app to manage an ipython console.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: """Initialize the app.""" warnings.warn("IPythonConsoleApp is deprecated. Use JupyterConsoleApp", stacklevel=2) super().__init__(*args, **kwargs) diff --git a/jupyter_client/ioloop/manager.py b/jupyter_client/ioloop/manager.py index 5b5c3dc45..3c44e6123 100644 --- a/jupyter_client/ioloop/manager.py +++ b/jupyter_client/ioloop/manager.py @@ -12,10 +12,10 @@ from .restarter import AsyncIOLoopKernelRestarter, IOLoopKernelRestarter -def as_zmqstream(f): +def as_zmqstream(f: t.Any) -> t.Callable: """Convert a socket to a zmq stream.""" - def wrapped(self, *args, **kwargs): + def wrapped(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: save_socket_class = None # zmqstreams only support sync sockets if self.context._socket_class is not zmq.Socket: @@ -37,7 +37,7 @@ class IOLoopKernelManager(KernelManager): loop = Instance("tornado.ioloop.IOLoop") - def _loop_default(self): + def _loop_default(self) -> ioloop.IOLoop: return ioloop.IOLoop.current() restarter_class = Type( @@ -52,7 +52,7 @@ def _loop_default(self): ) _restarter: t.Any = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True) - def start_restarter(self): + def start_restarter(self) -> None: """Start the restarter.""" if self.autorestart and self.has_kernel: if self._restarter is None: @@ -61,7 +61,7 @@ def start_restarter(self): ) self._restarter.start() - def stop_restarter(self): + def stop_restarter(self) -> None: """Stop the restarter.""" if self.autorestart and self._restarter is not None: self._restarter.stop() @@ -78,7 +78,7 @@ class AsyncIOLoopKernelManager(AsyncKernelManager): loop = Instance("tornado.ioloop.IOLoop") - def _loop_default(self): + def _loop_default(self) -> ioloop.IOLoop: return ioloop.IOLoop.current() restarter_class = Type( @@ -95,7 +95,7 @@ def _loop_default(self): "jupyter_client.ioloop.AsyncIOLoopKernelRestarter", 
allow_none=True ) - def start_restarter(self): + def start_restarter(self) -> None: """Start the restarter.""" if self.autorestart and self.has_kernel: if self._restarter is None: @@ -104,7 +104,7 @@ def start_restarter(self): ) self._restarter.start() - def stop_restarter(self): + def stop_restarter(self) -> None: """Stop the restarter.""" if self.autorestart and self._restarter is not None: self._restarter.stop() diff --git a/jupyter_client/ioloop/restarter.py b/jupyter_client/ioloop/restarter.py index d0c70396a..64b508402 100644 --- a/jupyter_client/ioloop/restarter.py +++ b/jupyter_client/ioloop/restarter.py @@ -7,6 +7,7 @@ # Distributed under the terms of the Modified BSD License. import time import warnings +from typing import Any from traitlets import Instance @@ -18,7 +19,7 @@ class IOLoopKernelRestarter(KernelRestarter): loop = Instance("tornado.ioloop.IOLoop") - def _loop_default(self): + def _loop_default(self) -> Any: warnings.warn( "IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2", DeprecationWarning, @@ -30,7 +31,7 @@ def _loop_default(self): _pcallback = None - def start(self): + def start(self) -> None: """Start the polling of the kernel.""" if self._pcallback is None: from tornado.ioloop import PeriodicCallback @@ -41,7 +42,7 @@ def start(self): ) self._pcallback.start() - def stop(self): + def stop(self) -> None: """Stop the kernel polling.""" if self._pcallback is not None: self._pcallback.stop() @@ -51,7 +52,7 @@ def stop(self): class AsyncIOLoopKernelRestarter(IOLoopKernelRestarter): """An async io loop kernel restarter.""" - async def poll(self): + async def poll(self) -> None: # type:ignore[override] """Poll the kernel.""" if self.debug: self.log.debug("Polling kernel...") diff --git a/jupyter_client/jsonutil.py b/jupyter_client/jsonutil.py index db46d1b11..36730513b 100644 --- a/jupyter_client/jsonutil.py +++ b/jupyter_client/jsonutil.py @@ -9,7 +9,7 @@ from binascii import b2a_base64 from collections.abc import Iterable from datetime import datetime -from typing import Optional, Union +from typing import Any, Optional, Union from dateutil.parser import parse as _dateutil_parse from dateutil.tz import tzlocal @@ -67,7 +67,7 @@ def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]: return s -def extract_dates(obj): +def extract_dates(obj: Any) -> Any: """extract ISO8601 dates from unpacked JSON""" if isinstance(obj, dict): new_obj = {} # don't clobber @@ -81,7 +81,7 @@ def extract_dates(obj): return obj -def squash_dates(obj): +def squash_dates(obj: Any) -> Any: """squash datetime objects into ISO8601 strings""" if isinstance(obj, dict): obj = dict(obj) # don't clobber @@ -94,7 +94,7 @@ def squash_dates(obj): return obj -def date_default(obj): +def date_default(obj: Any) -> Any: """DEPRECATED: Use jupyter_client.jsonutil.json_default""" warnings.warn( "date_default is deprecated since jupyter_client 7.0.0." @@ -104,7 +104,7 @@ def date_default(obj): return json_default(obj) -def json_default(obj): +def json_default(obj: Any) -> Any: """default function for packing objects in JSON.""" if isinstance(obj, datetime): obj = _ensure_tzinfo(obj) @@ -128,7 +128,7 @@ def json_default(obj): # Copy of the old ipykernel's json_clean # This is temporary, it should be removed when we deprecate support for # non-valid JSON messages -def json_clean(obj): +def json_clean(obj: Any) -> Any: # types that are 'atomic' and ok in json as-is. 
atomic_ok = (str, type(None)) diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index b66e15422..5d43c64ed 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -1,6 +1,7 @@ """An application to launch a kernel by name in a local subprocess.""" import os import signal +import typing as t import uuid from jupyter_core.application import JupyterApp, base_flags @@ -30,7 +31,7 @@ class KernelApp(JupyterApp): config=True ) - def initialize(self, argv=None): + def initialize(self, argv: t.Union[str, t.Sequence[str], None] = None) -> None: """Initialize the application.""" super().initialize(argv) @@ -48,7 +49,7 @@ def setup_signals(self) -> None: if os.name == "nt": return - def shutdown_handler(signo, frame): + def shutdown_handler(signo: int, frame: t.Any) -> None: self.loop.add_callback_from_signal(self.shutdown, signo) for sig in [signal.SIGTERM, signal.SIGINT]: diff --git a/jupyter_client/kernelspec.py b/jupyter_client/kernelspec.py index 26c36865b..ff2185a84 100644 --- a/jupyter_client/kernelspec.py +++ b/jupyter_client/kernelspec.py @@ -1,10 +1,13 @@ """Tools for managing kernel specs""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import json import os import re import shutil +import typing as t import warnings from jupyter_core.paths import SYSTEM_JUPYTER_PATH, jupyter_data_dir, jupyter_path @@ -32,7 +35,7 @@ class KernelSpec(HasTraits): metadata = Dict() @classmethod - def from_resource_dir(cls, resource_dir): + def from_resource_dir(cls: type[KernelSpec], resource_dir: str) -> KernelSpec: """Create a KernelSpec object by reading kernel.json Pass the path to the *directory* containing kernel.json. @@ -42,7 +45,7 @@ def from_resource_dir(cls, resource_dir): kernel_dict = json.load(f) return cls(resource_dir=resource_dir, **kernel_dict) - def to_dict(self): + def to_dict(self) -> dict[str, t.Any]: """Convert the kernel spec to a dict.""" d = { "argv": self.argv, @@ -55,7 +58,7 @@ def to_dict(self): return d - def to_json(self): + def to_json(self) -> str: """Serialise this kernelspec to a JSON object. Returns a string. @@ -66,7 +69,7 @@ def to_json(self): _kernel_name_pat = re.compile(r"^[a-z0-9._\-]+$", re.IGNORECASE) -def _is_valid_kernel_name(name): +def _is_valid_kernel_name(name: str) -> t.Any: """Check that a kernel name is valid.""" # quote is not unicode-safe on Python 2 return _kernel_name_pat.match(name) @@ -78,12 +81,12 @@ def _is_valid_kernel_name(name): ) -def _is_kernel_dir(path): +def _is_kernel_dir(path: str) -> bool: """Is ``path`` a kernel directory?""" return os.path.isdir(path) and os.path.isfile(pjoin(path, "kernel.json")) -def _list_kernels_in(dir): +def _list_kernels_in(dir: str | None) -> dict[str, str]: """Return a mapping of kernel names to resource directories from dir. If dir is None or does not exist, returns an empty dict. 
@@ -108,11 +111,11 @@ def _list_kernels_in(dir): class NoSuchKernel(KeyError): # noqa """An error raised when there is no kernel of a give name.""" - def __init__(self, name): + def __init__(self, name: str) -> None: """Initialize the error.""" self.name = name - def __str__(self): + def __str__(self) -> str: return f"No such kernel named {self.name}" @@ -137,12 +140,12 @@ class KernelSpecManager(LoggingConfigurable): data_dir = Unicode() - def _data_dir_default(self): + def _data_dir_default(self) -> str: return jupyter_data_dir() user_kernel_dir = Unicode() - def _user_kernel_dir_default(self): + def _user_kernel_dir_default(self) -> str: return pjoin(self.data_dir, "kernels") whitelist = Set( @@ -168,7 +171,7 @@ def _user_kernel_dir_default(self): # Method copied from # https://github.com/jupyterhub/jupyterhub/blob/d1a85e53dccfc7b1dd81b0c1985d158cc6b61820/jupyterhub/auth.py#L143-L161 @observe(*list(_deprecated_aliases)) - def _deprecated_trait(self, change): + def _deprecated_trait(self, change: t.Any) -> None: """observer for deprecated traits""" old_attr = change.name new_attr, version = self._deprecated_aliases[old_attr] @@ -183,7 +186,7 @@ def _deprecated_trait(self, change): ) setattr(self, new_attr, change.new) - def _kernel_dirs_default(self): + def _kernel_dirs_default(self) -> list[str]: dirs = jupyter_path("kernels") # At some point, we should stop adding .ipython/kernels to the path, # but the cost to keeping it is very small. @@ -196,7 +199,7 @@ def _kernel_dirs_default(self): pass return dirs - def find_kernel_specs(self): + def find_kernel_specs(self) -> dict[str, str]: """Returns a dict mapping kernel names to resource directories.""" d = {} for kernel_dir in self.kernel_dirs: @@ -225,7 +228,7 @@ def find_kernel_specs(self): return d # TODO: Caching? - def _get_kernel_spec_by_name(self, kernel_name, resource_dir): + def _get_kernel_spec_by_name(self, kernel_name: str, resource_dir: str) -> KernelSpec: """Returns a :class:`KernelSpec` instance for a given kernel_name and resource_dir. """ @@ -238,7 +241,8 @@ def _get_kernel_spec_by_name(self, kernel_name, resource_dir): pass else: if resource_dir == RESOURCES: - kspec = self.kernel_spec_class(resource_dir=resource_dir, **get_kernel_dict()) + kdict = get_kernel_dict() # type:ignore[no-untyped-call] + kspec = self.kernel_spec_class(resource_dir=resource_dir, **kdict) if not kspec: kspec = self.kernel_spec_class.from_resource_dir(resource_dir) @@ -247,7 +251,7 @@ def _get_kernel_spec_by_name(self, kernel_name, resource_dir): return kspec - def _find_spec_directory(self, kernel_name): + def _find_spec_directory(self, kernel_name: str) -> str | None: """Find the resource directory of a named kernel spec""" for kernel_dir in [kd for kd in self.kernel_dirs if os.path.isdir(kd)]: files = os.listdir(kernel_dir) @@ -263,8 +267,9 @@ def _find_spec_directory(self, kernel_name): pass else: return RESOURCES + return None - def get_kernel_spec(self, kernel_name): + def get_kernel_spec(self, kernel_name: str) -> KernelSpec: """Returns a :class:`KernelSpec` instance for the given kernel_name. Raises :exc:`NoSuchKernel` if the given kernel name is not found. @@ -281,7 +286,7 @@ def get_kernel_spec(self, kernel_name): return self._get_kernel_spec_by_name(kernel_name, resource_dir) - def get_all_specs(self): + def get_all_specs(self) -> dict[str, t.Any]: """Returns a dict mapping kernel names to kernelspecs. 
Returns a dict of the form:: @@ -313,7 +318,7 @@ def get_all_specs(self): self.log.warning("Error loading kernelspec %r", kname, exc_info=True) return res - def remove_kernel_spec(self, name): + def remove_kernel_spec(self, name: str) -> str: """Remove a kernel spec directory by name. Returns the path that was deleted. @@ -332,7 +337,9 @@ def remove_kernel_spec(self, name): shutil.rmtree(spec_dir) return spec_dir - def _get_destination_dir(self, kernel_name, user=False, prefix=None): + def _get_destination_dir( + self, kernel_name: str, user: bool = False, prefix: str | None = None + ) -> str: if user: return os.path.join(self.user_kernel_dir, kernel_name) elif prefix: @@ -341,8 +348,13 @@ def _get_destination_dir(self, kernel_name, user=False, prefix=None): return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name) def install_kernel_spec( - self, source_dir, kernel_name=None, user=False, replace=None, prefix=None - ): + self, + source_dir: str, + kernel_name: str | None = None, + user: bool = False, + replace: bool | None = None, + prefix: str | None = None, + ) -> str: """Install a kernel spec by copying its directory. If ``kernel_name`` is not given, the basename of ``source_dir`` will @@ -395,7 +407,7 @@ def install_kernel_spec( self.log.info("Installed kernelspec %s in %s", kernel_name, destination) return destination - def install_native_kernel_spec(self, user=False): + def install_native_kernel_spec(self, user: bool = False) -> None: """DEPRECATED: Use ipykernel.kernelspec.install""" warnings.warn( "install_native_kernel_spec is deprecated. Use ipykernel.kernelspec import install.", @@ -403,15 +415,15 @@ def install_native_kernel_spec(self, user=False): ) from ipykernel.kernelspec import install - install(self, user=user) + install(self, user=user) # type:ignore[no-untyped-call] -def find_kernel_specs(): +def find_kernel_specs() -> dict[str, str]: """Returns a dict mapping kernel names to resource directories.""" return KernelSpecManager().find_kernel_specs() -def get_kernel_spec(kernel_name): +def get_kernel_spec(kernel_name: str) -> KernelSpec: """Returns a :class:`KernelSpec` instance for the given kernel_name. Raises KeyError if the given kernel name is not found. 
@@ -419,7 +431,13 @@ def get_kernel_spec(kernel_name): return KernelSpecManager().get_kernel_spec(kernel_name) -def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False, prefix=None): +def install_kernel_spec( + source_dir: str, + kernel_name: str | None = None, + user: bool = False, + replace: bool | None = False, + prefix: str | None = None, +) -> str: """Install a kernel spec in a given directory.""" return KernelSpecManager().install_kernel_spec(source_dir, kernel_name, user, replace, prefix) @@ -427,9 +445,9 @@ def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False, install_kernel_spec.__doc__ = KernelSpecManager.install_kernel_spec.__doc__ -def install_native_kernel_spec(user=False): +def install_native_kernel_spec(user: bool = False) -> None: """Install the native kernel spec.""" - return KernelSpecManager().install_native_kernel_spec(user=user) + KernelSpecManager().install_native_kernel_spec(user=user) install_native_kernel_spec.__doc__ = KernelSpecManager.install_native_kernel_spec.__doc__ diff --git a/jupyter_client/kernelspecapp.py b/jupyter_client/kernelspecapp.py index eb0ce8a31..3465170e6 100644 --- a/jupyter_client/kernelspecapp.py +++ b/jupyter_client/kernelspecapp.py @@ -1,10 +1,13 @@ """Apps for managing kernel specs.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import errno import json import os.path import sys +import typing as t from jupyter_core.application import JupyterApp, base_aliases, base_flags from traitlets import Bool, Dict, Instance, List, Unicode @@ -35,21 +38,21 @@ class ListKernelSpecs(JupyterApp): "debug": base_flags["debug"], } - def _kernel_spec_manager_default(self): + def _kernel_spec_manager_default(self) -> KernelSpecManager: return KernelSpecManager(parent=self, data_dir=self.data_dir) - def start(self): + def start(self) -> dict[str, t.Any] | None: # type:ignore[override] """Start the application.""" paths = self.kernel_spec_manager.find_kernel_specs() specs = self.kernel_spec_manager.get_all_specs() if not self.json_output: if not specs: print("No kernels available") - return + return None # pad to width of longest kernel name name_len = len(sorted(paths, key=lambda name: len(name))[-1]) - def path_key(item): + def path_key(item: t.Any) -> t.Any: """sort key function for Jupyter path priority""" path = item[1] for idx, prefix in enumerate(self.jupyter_path): @@ -83,13 +86,13 @@ class InstallKernelSpec(JupyterApp): usage = "jupyter kernelspec install SOURCE_DIR [--options]" kernel_spec_manager = Instance(KernelSpecManager) - def _kernel_spec_manager_default(self): + def _kernel_spec_manager_default(self) -> KernelSpecManager: return KernelSpecManager(data_dir=self.data_dir) sourcedir = Unicode() kernel_name = Unicode("", config=True, help="Install the kernel spec with this name") - def _kernel_name_default(self): + def _kernel_name_default(self) -> str: return os.path.basename(self.sourcedir) user = Bool( @@ -131,7 +134,7 @@ def _kernel_name_default(self): "debug": base_flags["debug"], } - def parse_command_line(self, argv): + def parse_command_line(self, argv: None | list[str]) -> None: # type:ignore[override] """Parse the command line args.""" super().parse_command_line(argv) # accept positional arg as profile name @@ -141,7 +144,7 @@ def parse_command_line(self, argv): print("No source directory specified.", file=sys.stderr) self.exit(1) - def start(self): + def start(self) -> None: """Start the 
application.""" if self.user and self.prefix: self.exit("Can't specify both user and prefix. Please choose one or the other.") @@ -177,7 +180,7 @@ class RemoveKernelSpec(JupyterApp): kernel_spec_manager = Instance(KernelSpecManager) - def _kernel_spec_manager_default(self): + def _kernel_spec_manager_default(self) -> KernelSpecManager: return KernelSpecManager(data_dir=self.data_dir, parent=self) flags = { @@ -185,7 +188,7 @@ def _kernel_spec_manager_default(self): } flags.update(JupyterApp.flags) - def parse_command_line(self, argv): + def parse_command_line(self, argv: list[str] | None) -> None: # type:ignore[override] """Parse the command line args.""" super().parse_command_line(argv) # accept positional arg as profile name @@ -194,7 +197,7 @@ def parse_command_line(self, argv): else: self.exit("No kernelspec specified.") - def start(self): + def start(self) -> None: """Start the application.""" self.kernel_spec_manager.ensure_native_kernel = False spec_paths = self.kernel_spec_manager.find_kernel_specs() @@ -231,7 +234,7 @@ class InstallNativeKernelSpec(JupyterApp): description = """[DEPRECATED] Install the IPython kernel spec directory for this Python.""" kernel_spec_manager = Instance(KernelSpecManager) - def _kernel_spec_manager_default(self): # pragma: no cover + def _kernel_spec_manager_default(self) -> KernelSpecManager: # pragma: no cover return KernelSpecManager(data_dir=self.data_dir) user = Bool( @@ -251,7 +254,7 @@ def _kernel_spec_manager_default(self): # pragma: no cover "debug": base_flags["debug"], } - def start(self): # pragma: no cover + def start(self) -> None: # pragma: no cover """Start the application.""" self.log.warning( "`jupyter kernelspec install-self` is DEPRECATED as of 4.0." @@ -263,7 +266,9 @@ def start(self): # pragma: no cover print("ipykernel not available, can't install its spec.", file=sys.stderr) self.exit(1) try: - kernelspec.install(self.kernel_spec_manager, user=self.user) + kernelspec.install( + self.kernel_spec_manager, user=self.user + ) # type:ignore[no-untyped-call] except OSError as e: if e.errno == errno.EACCES: print(e, file=sys.stderr) @@ -282,7 +287,7 @@ class ListProvisioners(JupyterApp): version = __version__ description = """List available provisioners for use in kernel specifications.""" - def start(self): + def start(self) -> None: """Start the application.""" kfp = KernelProvisionerFactory.instance(parent=self) print("Available kernel provisioners:") @@ -322,7 +327,7 @@ class KernelSpecApp(Application): aliases = {} flags = {} - def start(self): + def start(self) -> None: """Start the application.""" if self.subapp is None: print("No subcommand specified. Must specify one of: %s" % list(self.subcommands)) diff --git a/jupyter_client/localinterfaces.py b/jupyter_client/localinterfaces.py index 4b9143bdd..ca684a6ba 100644 --- a/jupyter_client/localinterfaces.py +++ b/jupyter_client/localinterfaces.py @@ -1,21 +1,23 @@ """Utilities for identifying local IP addresses.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
+from __future__ import annotations + import os import re import socket import subprocess from subprocess import PIPE, Popen -from typing import Iterable, List +from typing import Any, Callable, Iterable, Sequence from warnings import warn -LOCAL_IPS: List = [] -PUBLIC_IPS: List = [] +LOCAL_IPS: list = [] +PUBLIC_IPS: list = [] LOCALHOST: str = "" -def _uniq_stable(elems: Iterable) -> List: +def _uniq_stable(elems: Iterable) -> list: """uniq_stable(elems) -> list Return from an iterable, a list of all the unique elements in the input, @@ -30,7 +32,7 @@ def _uniq_stable(elems: Iterable) -> List: return value -def _get_output(cmd): +def _get_output(cmd: str | Sequence[str]) -> str: """Get output of a command, raising IOError if it fails""" startupinfo = None if os.name == "nt": @@ -44,24 +46,24 @@ def _get_output(cmd): return stdout.decode("utf8", "replace") -def _only_once(f): +def _only_once(f: Callable) -> Callable: """decorator to only run a function once""" - f.called = False + f.called = False # type:ignore[attr-defined] - def wrapped(**kwargs): - if f.called: + def wrapped(**kwargs: Any) -> Any: + if f.called: # type:ignore[attr-defined] return ret = f(**kwargs) - f.called = True + f.called = True # type:ignore[attr-defined] return ret return wrapped -def _requires_ips(f): +def _requires_ips(f: Callable) -> Callable: """decorator to ensure load_ips has been run before f""" - def ips_loaded(*args, **kwargs): + def ips_loaded(*args: Any, **kwargs: Any) -> Any: _load_ips() return f(*args, **kwargs) @@ -73,7 +75,7 @@ class NoIPAddresses(Exception): # noqa pass -def _populate_from_list(addrs): +def _populate_from_list(addrs: Sequence[str] | None) -> None: """populate local and public IPs from flat list of all IPs""" if not addrs: raise NoIPAddresses @@ -102,7 +104,7 @@ def _populate_from_list(addrs): _ifconfig_ipv4_pat = re.compile(r"inet\b.*?(\d+\.\d+\.\d+\.\d+)", re.IGNORECASE) -def _load_ips_ifconfig(): +def _load_ips_ifconfig() -> None: """load ip addresses from `ifconfig` output (posix)""" try: @@ -120,7 +122,7 @@ def _load_ips_ifconfig(): _populate_from_list(addrs) -def _load_ips_ip(): +def _load_ips_ip() -> None: """load ip addresses from `ip addr` output (Linux)""" out = _get_output(["ip", "-f", "inet", "addr"]) @@ -136,7 +138,7 @@ def _load_ips_ip(): _ipconfig_ipv4_pat = re.compile(r"ipv4.*?(\d+\.\d+\.\d+\.\d+)$", re.IGNORECASE) -def _load_ips_ipconfig(): +def _load_ips_ipconfig() -> None: """load ip addresses from `ipconfig` output (Windows)""" out = _get_output("ipconfig") @@ -149,7 +151,7 @@ def _load_ips_ipconfig(): _populate_from_list(addrs) -def _load_ips_netifaces(): +def _load_ips_netifaces() -> None: """load ip addresses with netifaces""" import netifaces # type: ignore[import-not-found] @@ -179,7 +181,7 @@ def _load_ips_netifaces(): PUBLIC_IPS[:] = _uniq_stable(public_ips) -def _load_ips_gethostbyname(): +def _load_ips_gethostbyname() -> None: """load ip addresses with socket.gethostbyname_ex This can be slow. @@ -211,7 +213,7 @@ def _load_ips_gethostbyname(): LOCALHOST = LOCAL_IPS[0] -def _load_ips_dumb(): +def _load_ips_dumb() -> None: """Fallback in case of unexpected failure""" global LOCALHOST LOCALHOST = "127.0.0.1" @@ -220,7 +222,7 @@ def _load_ips_dumb(): @_only_once -def _load_ips(suppress_exceptions=True): +def _load_ips(suppress_exceptions: bool = True) -> None: """load the IPs that point to this machine This function will only ever be called once. 
@@ -266,30 +268,30 @@ def _load_ips(suppress_exceptions=True): @_requires_ips -def local_ips(): +def local_ips() -> list[str]: """return the IP addresses that point to this machine""" return LOCAL_IPS @_requires_ips -def public_ips(): +def public_ips() -> list[str]: """return the IP addresses for this machine that are visible to other machines""" return PUBLIC_IPS @_requires_ips -def localhost(): +def localhost() -> str: """return ip for localhost (almost always 127.0.0.1)""" return LOCALHOST @_requires_ips -def is_local_ip(ip): +def is_local_ip(ip: str) -> bool: """does `ip` point to this machine?""" return ip in LOCAL_IPS @_requires_ips -def is_public_ip(ip): +def is_public_ip(ip: str) -> bool: """is `ip` a publicly visible address?""" return ip in PUBLIC_IPS diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 3dff5433b..b9168db0f 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -76,7 +76,7 @@ def in_pending_state(method: F) -> F: @t.no_type_check @functools.wraps(method) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: """Create a future for the decorated method.""" if self._attempted_start or not self._ready: self._ready = _get_future() @@ -104,7 +104,7 @@ class KernelManager(ConnectionFileMixin): _ready: t.Optional[t.Union[Future, CFuture]] - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: """Initialize a kernel manager.""" self._owns_kernel = kwargs.pop("owns_kernel", True) super().__init__(**kwargs) @@ -304,7 +304,7 @@ def format_kernel_cmd(self, extra_arguments: t.Optional[t.List[str]] = None) -> pat = re.compile(r"\{([A-Za-z0-9_]+)\}") - def from_ns(match): + def from_ns(match: t.Any) -> t.Any: """Get the key out of ns if it's there, otherwise no change.""" return ns.get(match.group(1), match.group()) diff --git a/jupyter_client/managerabc.py b/jupyter_client/managerabc.py index 8e33069cb..c74ea1dce 100644 --- a/jupyter_client/managerabc.py +++ b/jupyter_client/managerabc.py @@ -2,6 +2,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
import abc +from typing import Any class KernelManagerABC(metaclass=abc.ABCMeta): @@ -9,11 +10,11 @@ class KernelManagerABC(metaclass=abc.ABCMeta): The docstrings for this class can be found in the base implementation: - `jupyter_client.kernelmanager.KernelManager` + `jupyter_client.manager.KernelManager` """ @abc.abstractproperty - def kernel(self): + def kernel(self) -> Any: pass # -------------------------------------------------------------------------- @@ -21,35 +22,35 @@ def kernel(self): # -------------------------------------------------------------------------- @abc.abstractmethod - def start_kernel(self, **kw): + def start_kernel(self, **kw: Any) -> None: """Start the kernel.""" pass @abc.abstractmethod - def shutdown_kernel(self, now=False, restart=False): + def shutdown_kernel(self, now: bool = False, restart: bool = False) -> None: """Shut down the kernel.""" pass @abc.abstractmethod - def restart_kernel(self, now=False, **kw): + def restart_kernel(self, now: bool = False, **kw: Any) -> None: """Restart the kernel.""" pass @abc.abstractproperty - def has_kernel(self): + def has_kernel(self) -> bool: pass @abc.abstractmethod - def interrupt_kernel(self): + def interrupt_kernel(self) -> None: """Interrupt the kernel.""" pass @abc.abstractmethod - def signal_kernel(self, signum): + def signal_kernel(self, signum: int) -> None: """Send a signal to the kernel.""" pass @abc.abstractmethod - def is_alive(self): + def is_alive(self) -> bool: """Test whether the kernel is alive.""" pass diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index 2ebd0e9dc..95b63d512 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -1,6 +1,8 @@ """A kernel manager for multiple kernels""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import asyncio import json import os @@ -31,7 +33,7 @@ def kernel_method(f: t.Callable) -> t.Callable: @wraps(f) def wrapped( self: t.Any, kernel_id: str, *args: t.Any, **kwargs: t.Any - ) -> t.Union[t.Callable, t.Awaitable]: + ) -> t.Callable | t.Awaitable: # get the kernel km = self.get_kernel(kernel_id) method = getattr(km, f.__name__) @@ -63,13 +65,13 @@ class MultiKernelManager(LoggingConfigurable): ).tag(config=True) @observe("kernel_manager_class") - def _kernel_manager_class_changed(self, change): + def _kernel_manager_class_changed(self, change: t.Any) -> None: self.kernel_manager_factory = self._create_kernel_manager_factory() kernel_manager_factory = Any(help="this is kernel_manager_class after import") @default("kernel_manager_factory") - def _kernel_manager_factory_default(self): + def _kernel_manager_factory_default(self) -> t.Callable: return self._create_kernel_manager_factory() def _create_kernel_manager_factory(self) -> t.Callable: @@ -98,7 +100,7 @@ def create_kernel_manager(*args: t.Any, **kwargs: t.Any) -> KernelManager: _pending_kernels = Dict() @property - def _starting_kernels(self): + def _starting_kernels(self) -> dict: """A shim for backwards compatibility.""" return self._pending_kernels @@ -112,11 +114,11 @@ def _context_default(self) -> zmq.Context: _kernels = Dict() - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - self.kernel_id_to_connection_file = {} + self.kernel_id_to_connection_file: dict[str, Path] = {} - def __del__(self): + def __del__(self) -> None: """Handle garbage collection. 
Destroy context if applicable.""" if self._created_context and self.context and not self.context.closed: if self.log: @@ -129,7 +131,7 @@ def __del__(self): else: super_del() - def list_kernel_ids(self) -> t.List[str]: + def list_kernel_ids(self) -> list[str]: """Return a list of the kernel ids of the active kernels.""" if self.external_connection_dir is not None: external_connection_dir = Path(self.external_connection_dir) @@ -188,8 +190,8 @@ def __contains__(self, kernel_id: str) -> bool: return kernel_id in self._kernels def pre_start_kernel( - self, kernel_name: t.Optional[str], kwargs: t.Any - ) -> t.Tuple[KernelManager, str, str]: + self, kernel_name: str | None, kwargs: t.Any + ) -> tuple[KernelManager, str, str]: # kwargs should be mutable, passing it as a dict argument. kernel_id = kwargs.pop("kernel_id", self.new_kernel_id(**kwargs)) if kernel_id in self: @@ -232,15 +234,13 @@ async def _remove_kernel_when_ready( except Exception as e: self.log.exception(e) - def _using_pending_kernels(self): + def _using_pending_kernels(self) -> bool: """Returns a boolean; a clearer method for determining if this multikernelmanager is using pending kernels or not """ return getattr(self, 'use_pending_kernels', False) - async def _async_start_kernel( - self, *, kernel_name: t.Optional[str] = None, **kwargs: t.Any - ) -> str: + async def _async_start_kernel(self, *, kernel_name: str | None = None, **kwargs: t.Any) -> str: """Start a new kernel. The caller can pick a kernel_id by passing one in as a keyword arg, @@ -278,8 +278,8 @@ async def _async_start_kernel( async def _async_shutdown_kernel( self, kernel_id: str, - now: t.Optional[bool] = False, - restart: t.Optional[bool] = False, + now: bool | None = False, + restart: bool | None = False, ) -> None: """Shutdown a kernel by its kernel uuid. @@ -323,15 +323,15 @@ async def _async_shutdown_kernel( shutdown_kernel = run_sync(_async_shutdown_kernel) @kernel_method - def request_shutdown(self, kernel_id: str, restart: t.Optional[bool] = False) -> None: + def request_shutdown(self, kernel_id: str, restart: bool | None = False) -> None: """Ask a kernel to shut down by its kernel uuid""" @kernel_method def finish_shutdown( self, kernel_id: str, - waittime: t.Optional[float] = None, - pollinterval: t.Optional[float] = 0.1, + waittime: float | None = None, + pollinterval: float | None = 0.1, ) -> None: """Wait for a kernel to finish shutting down, and kill it if it doesn't""" self.log.info("Kernel shutdown: %s", kernel_id) @@ -468,7 +468,7 @@ def remove_restart_callback( """remove a callback for the KernelRestarter""" @kernel_method - def get_connection_info(self, kernel_id: str) -> t.Dict[str, t.Any]: # type:ignore[empty-body] + def get_connection_info(self, kernel_id: str) -> dict[str, t.Any]: # type:ignore[empty-body] """Return a dictionary of connection data for a kernel. Parameters @@ -487,7 +487,7 @@ def get_connection_info(self, kernel_id: str) -> t.Dict[str, t.Any]: # type:ign @kernel_method def connect_iopub( # type:ignore[empty-body] - self, kernel_id: str, identity: t.Optional[bytes] = None + self, kernel_id: str, identity: bytes | None = None ) -> socket.socket: """Return a zmq Socket connected to the iopub channel. @@ -505,7 +505,7 @@ def connect_iopub( # type:ignore[empty-body] @kernel_method def connect_shell( # type:ignore[empty-body] - self, kernel_id: str, identity: t.Optional[bytes] = None + self, kernel_id: str, identity: bytes | None = None ) -> socket.socket: """Return a zmq Socket connected to the shell channel. 
@@ -523,7 +523,7 @@ def connect_shell( # type:ignore[empty-body] @kernel_method def connect_control( # type:ignore[empty-body] - self, kernel_id: str, identity: t.Optional[bytes] = None + self, kernel_id: str, identity: bytes | None = None ) -> socket.socket: """Return a zmq Socket connected to the control channel. @@ -541,7 +541,7 @@ def connect_control( # type:ignore[empty-body] @kernel_method def connect_stdin( # type:ignore[empty-body] - self, kernel_id: str, identity: t.Optional[bytes] = None + self, kernel_id: str, identity: bytes | None = None ) -> socket.socket: """Return a zmq Socket connected to the stdin channel. @@ -559,7 +559,7 @@ def connect_stdin( # type:ignore[empty-body] @kernel_method def connect_hb( # type:ignore[empty-body] - self, kernel_id: str, identity: t.Optional[bytes] = None + self, kernel_id: str, identity: bytes | None = None ) -> socket.socket: """Return a zmq Socket connected to the hb channel. diff --git a/jupyter_client/provisioning/factory.py b/jupyter_client/provisioning/factory.py index de2b6a2d4..fd256ca09 100644 --- a/jupyter_client/provisioning/factory.py +++ b/jupyter_client/provisioning/factory.py @@ -8,7 +8,7 @@ # See compatibility note on `group` keyword in https://docs.python.org/3/library/importlib.metadata.html#entry-points if sys.version_info < (3, 10): # pragma: no cover - from importlib_metadata import EntryPoint, entry_points + from importlib_metadata import EntryPoint, entry_points # type:ignore[import-not-found] else: # pragma: no cover from importlib.metadata import EntryPoint, entry_points @@ -43,7 +43,7 @@ class KernelProvisionerFactory(SingletonConfigurable): ) @default('default_provisioner_name') - def _default_provisioner_name_default(self): + def _default_provisioner_name_default(self) -> str: """The default provisioner name.""" return getenv(self.default_provisioner_name_env, "local-provisioner") diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py index 194ba9079..d41890f69 100644 --- a/jupyter_client/restarter.py +++ b/jupyter_client/restarter.py @@ -7,7 +7,10 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
+from __future__ import annotations + import time +import typing as t from traitlets import Bool, Dict, Float, Instance, Integer, default from traitlets.config.configurable import LoggingConfigurable @@ -52,25 +55,25 @@ class KernelRestarter(LoggingConfigurable): _last_dead = Float() @default("_last_dead") - def _default_last_dead(self): + def _default_last_dead(self) -> float: return time.time() callbacks = Dict() - def _callbacks_default(self): + def _callbacks_default(self) -> dict[str, list]: return {"restart": [], "dead": []} - def start(self): + def start(self) -> None: """Start the polling of the kernel.""" msg = "Must be implemented in a subclass" raise NotImplementedError(msg) - def stop(self): + def stop(self) -> None: """Stop the kernel polling.""" msg = "Must be implemented in a subclass" raise NotImplementedError(msg) - def add_callback(self, f, event="restart"): + def add_callback(self, f: t.Callable[..., t.Any], event: str = "restart") -> None: """register a callback to fire on a particular event Possible values for event: @@ -81,7 +84,7 @@ def add_callback(self, f, event="restart"): """ self.callbacks[event].append(f) - def remove_callback(self, f, event="restart"): + def remove_callback(self, f: t.Callable[..., t.Any], event: str = "restart") -> None: """unregister a callback to fire on a particular event Possible values for event: @@ -95,7 +98,7 @@ def remove_callback(self, f, event="restart"): except ValueError: pass - def _fire_callbacks(self, event): + def _fire_callbacks(self, event: t.Any) -> None: """fire our callbacks for a particular event""" for callback in self.callbacks[event]: try: @@ -108,7 +111,7 @@ def _fire_callbacks(self, event): exc_info=True, ) - def poll(self): + def poll(self) -> None: if self.debug: self.log.debug("Polling kernel...") if self.kernel_manager.shutting_down: diff --git a/jupyter_client/runapp.py b/jupyter_client/runapp.py index 9013f25bf..9ed4b1543 100644 --- a/jupyter_client/runapp.py +++ b/jupyter_client/runapp.py @@ -1,10 +1,13 @@ """A Jupyter console app to run files.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
+from __future__ import annotations + import queue import signal import sys import time +import typing as t from jupyter_core.application import JupyterApp, base_aliases, base_flags from traitlets import Any, Dict, Float @@ -57,14 +60,14 @@ class RunApp(JupyterApp, JupyterConsoleApp): # type:ignore[misc] """, ) - def parse_command_line(self, argv=None): + def parse_command_line(self, argv: list[str] | None = None) -> None: """Parse the command line arguments.""" super().parse_command_line(argv) self.build_kernel_argv(self.extra_args) self.filenames_to_run = self.extra_args[:] @catch_config_error - def initialize(self, argv=None): + def initialize(self, argv: list[str] | None = None) -> None: # type:ignore[override] """Initialize the app.""" self.log.debug("jupyter run: initialize...") super().initialize(argv) @@ -72,14 +75,14 @@ def initialize(self, argv=None): signal.signal(signal.SIGINT, self.handle_sigint) self.init_kernel_info() - def handle_sigint(self, *args): + def handle_sigint(self, *args: t.Any) -> None: """Handle SIGINT.""" if self.kernel_manager: self.kernel_manager.interrupt_kernel() else: self.log.error("Cannot interrupt kernels we didn't start.\n") - def init_kernel_info(self): + def init_kernel_info(self) -> None: """Wait for a kernel to be ready, and store kernel info""" timeout = self.kernel_timeout tic = time.time() @@ -97,7 +100,7 @@ def init_kernel_info(self): self.kernel_info = reply["content"] return - def start(self): + def start(self) -> None: """Start the application.""" self.log.debug("jupyter run: starting...") super().start() diff --git a/jupyter_client/session.py b/jupyter_client/session.py index 16373ce3b..2cc2874ef 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -61,7 +61,7 @@ # ----------------------------------------------------------------------------- -def squash_unicode(obj): +def squash_unicode(obj: t.Any) -> t.Any: """coerce unicode back to bytestrings.""" if isinstance(obj, dict): for key in list(obj.keys()): @@ -89,7 +89,7 @@ def squash_unicode(obj): # disallow nan, because it's not actually valid JSON -def json_packer(obj): +def json_packer(obj: t.Any) -> bytes: """Convert a json object to a bytes.""" try: return json.dumps( @@ -117,14 +117,14 @@ def json_packer(obj): return packed -def json_unpacker(s): +def json_unpacker(s: str | bytes) -> t.Any: """Convert a json bytes or string to an object.""" if isinstance(s, bytes): s = s.decode("utf8", "replace") return json.loads(s) -def pickle_packer(o): +def pickle_packer(o: t.Any) -> bytes: """Pack an object using the pickle module.""" return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) @@ -226,10 +226,10 @@ def _context_default(self) -> zmq.Context: loop = Instance("tornado.ioloop.IOLoop") - def _loop_default(self): + def _loop_default(self) -> IOLoop: return IOLoop.current() - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: """Initialize a session factory.""" super().__init__(**kwargs) @@ -359,7 +359,7 @@ class Session(Configurable): ) @observe("packer") - def _packer_changed(self, change): + def _packer_changed(self, change: t.Any) -> None: new = change["new"] if new.lower() == "json": self.pack = json_packer @@ -380,7 +380,7 @@ def _packer_changed(self, change): ) @observe("unpacker") - def _unpacker_changed(self, change): + def _unpacker_changed(self, change: t.Any) -> None: new = change["new"] if new.lower() == "json": self.pack = json_packer @@ -401,7 +401,7 @@ def _session_default(self) -> str: return u @observe("session") - def 
_session_changed(self, change): + def _session_changed(self, change: t.Any) -> None: self.bsession = self.session.encode("ascii") # bsession is the session as bytes @@ -431,7 +431,7 @@ def _key_default(self) -> bytes: return new_id_bytes() @observe("key") - def _key_changed(self, change): + def _key_changed(self, change: t.Any) -> None: self._new_auth() signature_scheme = Unicode( @@ -442,7 +442,7 @@ def _key_changed(self, change): ) @observe("signature_scheme") - def _signature_scheme_changed(self, change): + def _signature_scheme_changed(self, change: t.Any) -> None: new = change["new"] if not new.startswith("hmac-"): raise TraitError("signature_scheme must start with 'hmac-', got %r" % new) @@ -479,7 +479,7 @@ def _new_auth(self) -> None: keyfile = Unicode("", config=True, help="""path to file containing execution key.""") @observe("keyfile") - def _keyfile_changed(self, change): + def _keyfile_changed(self, change: t.Any) -> None: with open(change["new"], "rb") as f: self.key = f.read().strip() @@ -491,7 +491,7 @@ def _keyfile_changed(self, change): pack = Any(default_packer) # the actual packer function @observe("pack") - def _pack_changed(self, change): + def _pack_changed(self, change: t.Any) -> None: new = change["new"] if not callable(new): raise TypeError("packer must be callable, not %s" % type(new)) @@ -499,7 +499,7 @@ def _pack_changed(self, change): unpack = Any(default_unpacker) # the actual packer function @observe("unpack") - def _unpack_changed(self, change): + def _unpack_changed(self, change: t.Any) -> None: # unpacker is not checked - it is assumed to be new = change["new"] if not callable(new): @@ -525,7 +525,7 @@ def _unpack_changed(self, change): """, ) - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: """create a Session object Parameters @@ -588,7 +588,7 @@ def clone(self) -> Session: """ # make a copy new_session = type(self)() - for name in self.traits(): + for name in self.traits(): # type:ignore[no-untyped-call] setattr(new_session, name, getattr(self, name)) # fork digest_history new_session.digest_history = set() diff --git a/jupyter_client/ssh/forward.py b/jupyter_client/ssh/forward.py index 47e63f22e..e2f28d218 100644 --- a/jupyter_client/ssh/forward.py +++ b/jupyter_client/ssh/forward.py @@ -85,7 +85,7 @@ def handle(self): logger.debug("Tunnel closed ") -def forward_tunnel(local_port, remote_host, remote_port, transport): +def forward_tunnel(local_port: int, remote_host: str, remote_port: int, transport: t.Any) -> None: """Forward an ssh tunnel.""" # this is a little convoluted, but lets me configure things for the Handler diff --git a/jupyter_client/ssh/tunnel.py b/jupyter_client/ssh/tunnel.py index e98e46906..6ddeb0f2b 100644 --- a/jupyter_client/ssh/tunnel.py +++ b/jupyter_client/ssh/tunnel.py @@ -5,6 +5,8 @@ # Copyright (C) 2011- PyZMQ Developers # # Redistributed from IPython under the terms of the BSD License. 
+from __future__ import annotations + import atexit import os import re @@ -14,6 +16,7 @@ import warnings from getpass import getpass, getuser from multiprocessing import Process +from typing import Any, cast try: with warnings.catch_warnings(): @@ -36,7 +39,7 @@ class SSHException(Exception): # type:ignore[no-redef] # noqa pexpect = None -def select_random_ports(n): +def select_random_ports(n: int) -> list[int]: """Select and return n random ports that are available.""" ports = [] sockets = [] @@ -56,7 +59,7 @@ def select_random_ports(n): _password_pat = re.compile((br"pass(word|phrase):"), re.IGNORECASE) -def try_passwordless_ssh(server, keyfile, paramiko=None): +def try_passwordless_ssh(server: str, keyfile: str | None, paramiko: Any = None) -> Any: """Attempt to make an ssh connection without a password. This is mainly used for requiring password input only once when many tunnels may be connected to the same server. @@ -69,7 +72,7 @@ def try_passwordless_ssh(server, keyfile, paramiko=None): return f(server, keyfile) -def _try_passwordless_openssh(server, keyfile): +def _try_passwordless_openssh(server: str, keyfile: str | None) -> bool: """Try passwordless login with shell ssh command.""" if pexpect is None: msg = "pexpect unavailable, use paramiko" @@ -99,7 +102,7 @@ def _try_passwordless_openssh(server, keyfile): return False -def _try_passwordless_paramiko(server, keyfile): +def _try_passwordless_paramiko(server: str, keyfile: str | None) -> bool: """Try passwordless login with paramiko.""" if paramiko is None: msg = "Paramiko unavailable, " # type:ignore[unreachable] @@ -121,7 +124,15 @@ def _try_passwordless_paramiko(server, keyfile): return True -def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60): +def tunnel_connection( + socket: socket.socket, + addr: str, + server: str, + keyfile: str | None = None, + password: str | None = None, + paramiko: Any = None, + timeout: int = 60, +) -> int: """Connect a socket to an address via an ssh tunnel. This is a wrapper for socket.connect(addr), when addr is not accessible @@ -142,7 +153,14 @@ def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramik return tunnel -def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60): +def open_tunnel( + addr: str, + server: str, + keyfile: str | None = None, + password: str | None = None, + paramiko: Any = None, + timeout: int = 60, +) -> tuple[str, int]: """Open a tunneled connection from a 0MQ url. For use inside tunnel_connection. 
@@ -157,25 +175,31 @@ def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeou lport = select_random_ports(1)[0] _, addr = addr.split("://") ip, rport = addr.split(":") - rport = int(rport) + rport_int = int(rport) paramiko = sys.platform == "win32" if paramiko is None else paramiko_tunnel tunnelf = paramiko_tunnel if paramiko else openssh_tunnel tunnel = tunnelf( lport, - rport, + rport_int, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout, ) - return "tcp://127.0.0.1:%i" % lport, tunnel + return "tcp://127.0.0.1:%i" % lport, cast(int, tunnel) def openssh_tunnel( - lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60 -): + lport: int, + rport: int, + server: str, + remoteip: str = "127.0.0.1", + keyfile: str | None = None, + password: str | None | bool = None, + timeout: int = 60, +) -> int: """Create an ssh tunnel using command-line ssh that connects port lport on this machine to localhost:rport on server. The tunnel will automatically close when not in use, remaining open @@ -277,26 +301,32 @@ def openssh_tunnel( failed = True -def _stop_tunnel(cmd): +def _stop_tunnel(cmd: Any) -> None: pexpect.run(cmd) -def _split_server(server): +def _split_server(server: str) -> tuple[str, str, int]: if "@" in server: username, server = server.split("@", 1) else: username = getuser() if ":" in server: - server, port = server.split(":") - port = int(port) + server, port_str = server.split(":") + port = int(port_str) else: port = 22 return username, server, port def paramiko_tunnel( - lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60 -): + lport: int, + rport: int, + server: str, + remoteip: str = "127.0.0.1", + keyfile: str | None = None, + password: str | None = None, + timeout: float = 60, +) -> Process: """launch a tunner with paramiko in a subprocess. This should only be used when shell ssh is unavailable (e.g. Windows). @@ -353,7 +383,14 @@ def paramiko_tunnel( return p -def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): +def _paramiko_tunnel( + lport: int, + rport: int, + server: str, + remoteip: str, + keyfile: str | None = None, + password: str | None = None, +) -> None: """Function for actually starting a paramiko tunnel, to be passed to multiprocessing.Process(target=this), and not called directly. """ diff --git a/jupyter_client/threaded.py b/jupyter_client/threaded.py index 6fa1e0ed0..48e0ae771 100644 --- a/jupyter_client/threaded.py +++ b/jupyter_client/threaded.py @@ -57,7 +57,7 @@ def __init__( self.ioloop = loop f: Future = Future() - def setup_stream(): + def setup_stream() -> None: try: assert self.socket is not None self.stream = zmqstream.ZMQStream(self.socket, self.ioloop) @@ -92,7 +92,7 @@ def close(self) -> None: # c.f.Future for threadsafe results f: Future = Future() - def close_stream(): + def close_stream() -> None: try: if self.stream is not None: self.stream.close(linger=0) @@ -129,7 +129,7 @@ def send(self, msg: Dict[str, Any]) -> None: thread control of the action. 
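
The rport/rport_int and port/port_str renames above are the usual answer to mypy inferring a variable's type from its first assignment: rebinding a name from str to int is reported as an incompatible assignment, so the converted value gets a name of its own. The same idiom in isolation (a hypothetical helper, not code from the patch):

    from __future__ import annotations


    def split_host(server: str) -> tuple[str, int]:
        host, port_str = server.split(":", 1)
        # Writing ``port_str = int(port_str)`` would make mypy object to an
        # int being assigned to a str-typed name, hence the second variable.
        port = int(port_str)
        return host, port


    print(split_host("example.com:2222"))  # ('example.com', 2222)
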
""" - def thread_send(): + def thread_send() -> None: assert self.session is not None self.session.send(self.stream, msg) @@ -192,7 +192,7 @@ def flush(self, timeout: float = 1.0) -> None: _msg = "Attempt to flush closed stream" raise OSError(_msg) - def flush(f): + def flush(f: Any) -> None: try: self._flush() except Exception as e: @@ -224,7 +224,7 @@ class IOLoopThread(Thread): _exiting = False ioloop = None - def __init__(self): + def __init__(self) -> None: """Initialize an io loop thread.""" super().__init__() self.daemon = True @@ -254,7 +254,7 @@ def run(self) -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - async def assign_ioloop(): + async def assign_ioloop() -> None: self.ioloop = IOLoop.current() loop.run_until_complete(assign_ioloop()) @@ -265,7 +265,7 @@ async def assign_ioloop(): loop.run_until_complete(self._async_run()) - async def _async_run(self): + async def _async_run(self) -> None: """Run forever (until self._exiting is set)""" while not self._exiting: await asyncio.sleep(1) @@ -282,7 +282,7 @@ def stop(self) -> None: self.close() self.ioloop = None - def __del__(self): + def __del__(self) -> None: self.close() def close(self) -> None: @@ -298,9 +298,10 @@ class ThreadedKernelClient(KernelClient): """A KernelClient that provides thread-safe sockets with async callbacks on message replies.""" @property - def ioloop(self): + def ioloop(self) -> Optional[IOLoop]: # type:ignore[override] if self.ioloop_thread: return self.ioloop_thread.ioloop + return None ioloop_thread = Instance(IOLoopThread, allow_none=True) diff --git a/jupyter_client/utils.py b/jupyter_client/utils.py index 37eb3dc1a..555777470 100644 --- a/jupyter_client/utils.py +++ b/jupyter_client/utils.py @@ -3,14 +3,17 @@ - provides utility wrappers to run asynchronous functions in a blocking environment. - vendor functions from ipython_genutils that should be retired at some point. """ +from __future__ import annotations + import os +from typing import Sequence from jupyter_core.utils import ensure_async, run_sync # noqa: F401 # noqa: F401 from .session import utcnow # noqa -def _filefind(filename, path_dirs=None): +def _filefind(filename: str, path_dirs: str | Sequence[str] | None = None) -> str: """Find a file by looking through a sequence of paths. This iterates through a sequence of paths looking for a file and returns @@ -64,7 +67,7 @@ def _filefind(filename, path_dirs=None): raise OSError(msg) -def _expand_path(s): +def _expand_path(s: str) -> str: """Expand $VARS and ~names in a string, like a shell :Examples: diff --git a/jupyter_client/win_interrupt.py b/jupyter_client/win_interrupt.py index 20a3a7f69..c823d4db3 100644 --- a/jupyter_client/win_interrupt.py +++ b/jupyter_client/win_interrupt.py @@ -4,11 +4,10 @@ ipykernel.parentpoller.ParentPollerWindows for a Python implementation. """ import ctypes -from typing import no_type_check +from typing import Any -@no_type_check -def create_interrupt_event(): +def create_interrupt_event() -> Any: """Create an interrupt event handle. 
The parent process should call this to create the @@ -33,12 +32,11 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # noqa sa.lpSecurityDescriptor = 0 sa.bInheritHandle = 1 - return ctypes.windll.kernel32.CreateEventA( + return ctypes.windll.kernel32.CreateEventA( # type:ignore[attr-defined] sa_p, False, False, "" # lpEventAttributes # bManualReset # bInitialState ) # lpName -@no_type_check -def send_interrupt(interrupt_handle): +def send_interrupt(interrupt_handle: Any) -> None: """Sends an interrupt event using the specified handle.""" - ctypes.windll.kernel32.SetEvent(interrupt_handle) + ctypes.windll.kernel32.SetEvent(interrupt_handle) # type:ignore[attr-defined] diff --git a/pyproject.toml b/pyproject.toml index edf49e0f7..475371bc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ nowarn = "test -W default {args}" features = ["test"] dependencies = ["mypy~=1.6", "traitlets>=5.11.2", "jupyter_core>=5.3.2"] [tool.hatch.envs.typing.scripts] -test = "mypy --install-types --non-interactive {args:.}" +test = "mypy --install-types --non-interactive {args}" [tool.hatch.envs.lint] dependencies = [ @@ -166,28 +166,17 @@ relative_files = true source = ["jupyter_client"] [tool.mypy] -check_untyped_defs = true +files = "jupyter_client" +python_version = "3.8" +strict = true disallow_any_generics = false -disallow_incomplete_defs = true -disallow_untyped_decorators = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] -no_implicit_optional = true no_implicit_reexport = false pretty = true show_error_context = true show_error_codes = true -strict_equality = true -strict_optional = true -warn_unused_configs = true -warn_redundant_casts = true warn_return_any = false warn_unreachable = true -warn_unused_ignores = true - -[[tool.mypy.overrides]] -module = "tests.*" -disable_error_code = ["ignore-without-code"] -warn_unreachable = false [tool.black] line-length = 100 @@ -300,4 +289,4 @@ fail-under=90 exclude = ["docs", "test"] [tool.repo-review] -ignore = ["PY007", "PP308", "GH102", "PC140", "MY101"] +ignore = ["PY007", "PP308", "GH102", "PC140"] diff --git a/tests/test_kernelspecapp.py b/tests/test_kernelspecapp.py index 7fa161eed..a78e5fbd6 100644 --- a/tests/test_kernelspecapp.py +++ b/tests/test_kernelspecapp.py @@ -27,7 +27,7 @@ def test_kernelspec_sub_apps(jp_kernel_dir): app1 = ListKernelSpecs() app1.kernel_spec_manager.kernel_dirs.append(kernel_dir) specs = app1.start() - assert 'echo' in specs + assert specs and 'echo' in specs app2 = RemoveKernelSpec(spec_names=['echo'], force=True) app2.kernel_spec_manager.kernel_dirs.append(kernel_dir) @@ -36,7 +36,7 @@ def test_kernelspec_sub_apps(jp_kernel_dir): app3 = ListKernelSpecs() app3.kernel_spec_manager.kernel_dirs.append(kernel_dir) specs = app3.start() - assert 'echo' not in specs + assert specs and 'echo' not in specs def test_kernelspec_app(): diff --git a/tests/test_localinterfaces.py b/tests/test_localinterfaces.py index 86edc4e56..b299fae56 100644 --- a/tests/test_localinterfaces.py +++ b/tests/test_localinterfaces.py @@ -11,7 +11,7 @@ def test_load_ips(): # Override the machinery that skips it if it was called before - localinterfaces._load_ips.called = False + localinterfaces._load_ips.called = False # type:ignore[attr-defined] # Just check this doesn't error localinterfaces._load_ips(suppress_exceptions=False) From c4673c7f9b137681ff53fed0ee8fad91219e7c48 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Sun, 15 Oct 2023 21:16:33 -0500 Subject: [PATCH 2/3] fix imports --- 
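
The [tool.mypy] table rewritten above trades the hand-maintained flag list for strict = true, which covers everything the old list spelled out (check_untyped_defs, strict_equality, warn_unused_ignores and so on) and switches on stricter checks such as disallow_untyped_defs; disallow_any_generics, warn_return_any and no_implicit_reexport are then turned back off individually. Roughly, the newly enabled checks are why code like the first definition below no longer passes; a contrived sketch, not code from the repository:

    def area(width, height):  # mypy --strict: "Function is missing a type annotation"
        return width * height


    def area_typed(width: float, height: float) -> float:  # accepted
        return width * height


    print(area_typed(3.0, 4.0))  # 12.0
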
jupyter_client/asynchronous/client.py | 2 ++ jupyter_client/blocking/client.py | 2 ++ jupyter_client/clientabc.py | 14 ++++---- jupyter_client/connect.py | 48 ++++++++++++++------------- 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 53c68ffba..118734161 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -1,6 +1,8 @@ """Implements an async kernel client""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import typing as t import zmq.asyncio diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index bff55f5a7..5c815eb8d 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -4,6 +4,8 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import typing as t from traitlets import Type diff --git a/jupyter_client/clientabc.py b/jupyter_client/clientabc.py index cc14bda67..d003fe173 100644 --- a/jupyter_client/clientabc.py +++ b/jupyter_client/clientabc.py @@ -8,8 +8,10 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations + import abc -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from .channelsabc import ChannelABC @@ -32,23 +34,23 @@ def kernel(self) -> Any: pass @abc.abstractproperty - def shell_channel_class(self) -> Type[ChannelABC]: + def shell_channel_class(self) -> type[ChannelABC]: pass @abc.abstractproperty - def iopub_channel_class(self) -> Type[ChannelABC]: + def iopub_channel_class(self) -> type[ChannelABC]: pass @abc.abstractproperty - def hb_channel_class(self) -> Type[ChannelABC]: + def hb_channel_class(self) -> type[ChannelABC]: pass @abc.abstractproperty - def stdin_channel_class(self) -> Type[ChannelABC]: + def stdin_channel_class(self) -> type[ChannelABC]: pass @abc.abstractproperty - def control_channel_class(self) -> Type[ChannelABC]: + def control_channel_class(self) -> type[ChannelABC]: pass # -------------------------------------------------------------------------- diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index 2564b7693..a634be3de 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -5,6 +5,8 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import errno import glob import json @@ -14,7 +16,7 @@ import tempfile import warnings from getpass import getpass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Union, cast import zmq from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write @@ -34,7 +36,7 @@ def write_connection_file( - fname: Optional[str] = None, + fname: str | None = None, shell_port: int = 0, iopub_port: int = 0, stdin_port: int = 0, @@ -46,7 +48,7 @@ def write_connection_file( signature_scheme: str = "hmac-sha256", kernel_name: str = "", **kwargs: Any, -) -> Tuple[str, KernelConnectionInfo]: +) -> tuple[str, KernelConnectionInfo]: """Generates a JSON config file, including the selection of random ports. 
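
write_connection_file now advertises its return type as tuple[str, KernelConnectionInfo], so both the written path and the connection dict come back typed. A quick usage sketch (ports left at 0 are chosen at random; the example removes the file it creates):

    import os

    from jupyter_client.connect import write_connection_file

    # fname is the path of the JSON file that was written; info is the dict of
    # ports, ip, key, transport, etc. that went into it.
    fname, info = write_connection_file(transport="tcp", ip="127.0.0.1")
    print(fname, info["shell_port"])
    os.remove(fname)  # tidy up the file created for this example
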
Parameters @@ -96,8 +98,8 @@ def write_connection_file( # Find open ports as necessary. - ports: List[int] = [] - sockets: List[socket.socket] = [] + ports: list[int] = [] + sockets: list[socket.socket] = [] ports_needed = ( int(shell_port <= 0) + int(iopub_port <= 0) @@ -174,8 +176,8 @@ def write_connection_file( def find_connection_file( filename: str = "kernel-*.json", - path: Optional[Union[str, List[str]]] = None, - profile: Optional[str] = None, + path: str | list[str] | None = None, + profile: str | None = None, ) -> str: """find a connection file, and return its absolute path. @@ -237,10 +239,10 @@ def find_connection_file( def tunnel_to_kernel( - connection_info: Union[str, KernelConnectionInfo], + connection_info: str | KernelConnectionInfo, sshserver: str, - sshkey: Optional[str] = None, -) -> Tuple[Any, ...]: + sshkey: str | None = None, +) -> tuple[Any, ...]: """tunnel connections to a kernel via ssh This will open five SSH tunnels from localhost on this machine to the @@ -287,7 +289,7 @@ def tunnel_to_kernel( remote_ip = cf["ip"] if tunnel.try_passwordless_ssh(sshserver, sshkey): - password: Union[bool, str] = False + password: bool | str = False else: password = getpass("SSH Password for %s: " % sshserver) @@ -315,7 +317,7 @@ def tunnel_to_kernel( class ConnectionFileMixin(LoggingConfigurable): """Mixin for configurable classes that work with connection files""" - data_dir: Union[str, Unicode] = Unicode() + data_dir: str | Unicode = Unicode() def _data_dir_default(self) -> str: return jupyter_data_dir() @@ -334,7 +336,7 @@ def _data_dir_default(self) -> str: _connection_file_written = Bool(False) transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True) - kernel_name: Union[str, Unicode] = Unicode() + kernel_name: str | Unicode = Unicode() context = Instance(zmq.Context) @@ -369,10 +371,10 @@ def _ip_changed(self, change: Any) -> None: control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]") # names of the ports with random assignment - _random_port_names: Optional[List[str]] = None + _random_port_names: list[str] | None = None @property - def ports(self) -> List[int]: + def ports(self) -> list[int]: return [getattr(self, name) for name in port_names] # The Session to use for communication with the kernel. @@ -516,7 +518,7 @@ def write_connection_file(self, **kwargs: Any) -> None: self._connection_file_written = True - def load_connection_file(self, connection_file: Optional[str] = None) -> None: + def load_connection_file(self, connection_file: str | None = None) -> None: """Load connection info from JSON dict in self.connection_file. 
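
Two hunks above show the other spot where annotations have to be written out by hand: names whose first assignment is an empty literal (ports: list[int] = []) or whose branches assign different types (password: bool | str = False) are annotated at that first assignment, because mypy has nothing else to infer from. The same idiom in isolation (a hypothetical function, not from the patch):

    from __future__ import annotations


    def positive_ports(candidates: list[int]) -> list[int]:
        # An empty list gives mypy nothing to infer, and ``note`` is later
        # rebound to a str, so both names are annotated where first assigned.
        ports: list[int] = []
        note: bool | str = False
        for port in candidates:
            if port > 0:
                ports.append(port)
            else:
                note = "dropped a non-positive port"
        if note:
            print(note)
        return ports


    print(positive_ports([8888, 0, 9999]))  # [8888, 9999]
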
Parameters @@ -643,7 +645,7 @@ def _make_url(self, channel: str) -> str: return f"{transport}://{ip}-{port}" def _create_connected_socket( - self, channel: str, identity: Optional[bytes] = None + self, channel: str, identity: bytes | None = None ) -> zmq.sugar.socket.Socket: """Create a zmq Socket and connect it to the kernel.""" url = self._make_url(channel) @@ -657,25 +659,25 @@ def _create_connected_socket( sock.connect(url) return sock - def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: + def connect_iopub(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the IOPub channel""" sock = self._create_connected_socket("iopub", identity=identity) sock.setsockopt(zmq.SUBSCRIBE, b"") return sock - def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: + def connect_shell(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Shell channel""" return self._create_connected_socket("shell", identity=identity) - def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: + def connect_stdin(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the StdIn channel""" return self._create_connected_socket("stdin", identity=identity) - def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: + def connect_hb(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Heartbeat channel""" return self._create_connected_socket("hb", identity=identity) - def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: + def connect_control(self, identity: bytes | None = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Control channel""" return self._create_connected_socket("control", identity=identity) @@ -693,7 +695,7 @@ class is attempting to resolve (minimize). def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self.currently_used_ports: Set[int] = set() + self.currently_used_ports: set[int] = set() def find_available_port(self, ip: str) -> int: while True: From d5c604d7fce7bfcbcc348b8a33a2e37a74ab18e5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Oct 2023 18:48:46 -0500 Subject: [PATCH 3/3] ignore pep585 and 604 typing changes --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 475371bc9..9a0b3fbf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -251,6 +251,11 @@ ignore = [ "PLW0603", # Mutable class attributes should be annotated with `typing.ClassVar` "RUF012", + # non-pep585-annotation + "UP006", + # non-pep604-annotation + "UP007", + ] unfixable = [ # Don't touch print statements
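
The UP006/UP007 ignores added in the third patch keep ruff from asking for the PEP 585 / PEP 604 spellings in the code that still uses typing.Dict and typing.Optional, so that rewrite stays out of this series. The two spellings describe the same types; a side-by-side sketch (hypothetical lookup functions, not from the repository):

    from __future__ import annotations

    from typing import Dict, Optional


    def lookup_old(table: Dict[str, int], key: str) -> Optional[int]:
        # The spelling UP006 (non-pep585-annotation) and UP007
        # (non-pep604-annotation) would flag.
        return table.get(key)


    def lookup_new(table: dict[str, int], key: str) -> int | None:
        # The equivalent PEP 585 / PEP 604 form used in the converted modules.
        return table.get(key)


    print(lookup_old({"a": 1}, "a"), lookup_new({"a": 1}, "b"))  # 1 None
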