Skip to content

Commit

Permalink
fix(typing): Resolve mypy==1.11.0 issues in plugin_registry (vega…
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Jul 20, 2024
1 parent b9d3fd4 commit b1b8c6e
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 43 deletions.
2 changes: 1 addition & 1 deletion altair/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
VegaLiteCompilerType = Callable[[dict], dict]


class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType]):
class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType, dict]):
pass
18 changes: 8 additions & 10 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
Dict,
overload,
runtime_checkable,
Callable,
)
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, ParamSpec, Concatenate
from pathlib import Path
from functools import partial
import sys
Expand Down Expand Up @@ -82,17 +83,14 @@ def is_data_type(obj: Any) -> TypeIs[DataType]:
# VegaLite spec, after the Data model has been put into a schema compliant
# form.
# ==============================================================================
class DataTransformerType(Protocol):
@overload
def __call__(self, data: None = None, **kwargs) -> DataTransformerType: ...
@overload
def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: ...
def __call__(
self, data: DataType | None = None, **kwargs
) -> DataTransformerType | VegaLiteDataDict: ...

P = ParamSpec("P")
# NOTE: `Any` required due to the complexity of existing signatures imported in `altair.vegalite.v5.data.py`
R = TypeVar("R", VegaLiteDataDict, Any)
DataTransformerType = Callable[Concatenate[DataType, P], R]

class DataTransformerRegistry(PluginRegistry[DataTransformerType]):

class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]):
_global_settings = {"consolidate_datasets": True}

@property
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
]


class RendererRegistry(PluginRegistry[RendererType]):
class RendererRegistry(PluginRegistry[RendererType, MimeBundleType]):
entrypoint_err_messages = {
"notebook": textwrap.dedent(
"""
Expand Down
95 changes: 67 additions & 28 deletions altair/utils/plugin_registry.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
from __future__ import annotations

from functools import partial
from typing import Any, Generic, TypeVar, cast, Callable, TYPE_CHECKING
from typing import Any, Generic, cast, Callable, TYPE_CHECKING
from typing_extensions import TypeAliasType, TypeVar, TypeIs

from importlib.metadata import entry_points

from altair.utils.deprecation import deprecated_warn

if TYPE_CHECKING:
from types import TracebackType

T = TypeVar("T")
R = TypeVar("R")
Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,))
PluginT = TypeVar("PluginT", bound=Plugin[Any])
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]


def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]:
"""Converts a type to guard function.
Added for compatibility with original `PluginRegistry` default.
"""

def func(obj: object, /) -> TypeIs[type[T]]:
return isinstance(obj, tp)

PluginType = TypeVar("PluginType")
return func


class NoSuchEntryPoint(Exception):
Expand Down Expand Up @@ -49,7 +67,7 @@ def __repr__(self) -> str:
return f"{self.registry.__class__.__name__}.enable({self.name!r})"


class PluginRegistry(Generic[PluginType]):
class PluginRegistry(Generic[PluginT, R]):
"""A registry for plugins.
This is a plugin registry that allows plugins to be loaded/registered
Expand All @@ -74,26 +92,44 @@ class PluginRegistry(Generic[PluginType]):
# in the registry rather than passed to the plugins
_global_settings: dict[str, Any] = {}

def __init__(self, entry_point_group: str = "", plugin_type: type = Callable): # type: ignore[assignment]
def __init__(
self, entry_point_group: str = "", plugin_type: IsPlugin = callable
) -> None:
"""Create a PluginRegistry for a named entry point group.
Parameters
==========
entry_point_group: str
The name of the entry point group.
plugin_type: object
A type that will optionally be used for runtime type checking of
loaded plugins using isinstance.
plugin_type
A type narrowing function that will optionally be used for runtime
type checking loaded plugins.
References
==========
https://typing.readthedocs.io/en/latest/spec/narrowing.html
"""
self.entry_point_group: str = entry_point_group
self.plugin_type: type[Any] = plugin_type
self._active: PluginType | None = None
self.plugin_type: IsPlugin
if plugin_type is not callable and isinstance(plugin_type, type):
msg = (
f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
f"{type(self).__name__!r}(plugin_type)\n\n"
f"See also:\n"
f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
f"https://docs.astral.sh/ruff/rules/assert/"
)
deprecated_warn(msg, version="5.4.0")
self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
else:
self.plugin_type = plugin_type
self._active: Plugin[R] | None = None
self._active_name: str = ""
self._plugins: dict[str, PluginType] = {}
self._plugins: dict[str, PluginT] = {}
self._options: dict[str, Any] = {}
self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

def register(self, name: str, value: PluginType | Any | None) -> PluginType | None:
def register(self, name: str, value: PluginT | None) -> PluginT | None:
"""Register a plugin by name and value.
This method is used for explicit registration of a plugin and shouldn't be
Expand All @@ -113,12 +149,12 @@ def register(self, name: str, value: PluginType | Any | None) -> PluginType | No
"""
if value is None:
return self._plugins.pop(name, None)
else:
assert isinstance(
value, self.plugin_type
) # Should ideally be fixed by better annotating plugin_type
elif self.plugin_type(value):
self._plugins[name] = value
return value
else:
msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
raise TypeError(msg)

def names(self) -> list[str]:
"""List the names of the registered and entry points plugins."""
Expand Down Expand Up @@ -163,7 +199,7 @@ def _enable(self, name: str, **options) -> None:
raise ValueError(self.entrypoint_err_messages[name]) from err
else:
raise NoSuchEntryPoint(self.entry_point_group, name) from err
value = cast(PluginType, ep.load())
value = cast(PluginT, ep.load())
self.register(name, value)
self._active_name = name
self._active = self._plugins[name]
Expand Down Expand Up @@ -204,18 +240,21 @@ def options(self) -> dict[str, Any]:
"""Return the current options dictionary"""
return self._options

def get(self) -> PluginType | Callable[..., Any] | None:
def get(self) -> partial[R] | Plugin[R] | None:
"""Return the currently active plugin."""
if self._options:
if func := self._active:
# NOTE: Fully do not understand this one
# error: Argument 1 to "partial" has incompatible type "PluginType"; expected "Callable[..., Never]"
return partial(func, **self._options) # type: ignore[arg-type]
else:
msg = "Unclear what this meant by passing to curry."
raise TypeError(msg)
else:
return self._active
if (func := self._active) and self.plugin_type(func):
return partial(func, **self._options) if self._options else func
elif self._active is not None:
msg = (
f"{type(self).__name__!r} requires all plugins to be callable objects, "
f"but {type(self._active).__name__!r} is not callable."
)
raise TypeError(msg)
elif TYPE_CHECKING:
# NOTE: The `None` return is implicit, but `mypy` isn't satisfied
# - `ruff` will factor out explicit `None` return
# - `pyright` has no issue
raise NotImplementedError

def __repr__(self) -> str:
return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})"
Expand All @@ -228,6 +267,6 @@ def importlib_metadata_get(group):
# also get compatibility with the importlib_metadata package which had a different
# deprecation cycle for 'get'
if hasattr(ep, "select"):
return ep.select(group=group)
return ep.select(group=group) # pyright: ignore
else:
return ep.get(group, [])
2 changes: 1 addition & 1 deletion altair/utils/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
ThemeType = Callable[..., dict]


class ThemeRegistry(PluginRegistry[ThemeType]):
class ThemeRegistry(PluginRegistry[ThemeType, dict]):
pass
3 changes: 2 additions & 1 deletion altair/vegalite/v5/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP)
data_transformers.register("default", default_data_transformer)
data_transformers.register("json", to_json)
data_transformers.register("csv", to_csv)
# FIXME: `to_csv` cannot accept all `DataType` https://github.com/vega/altair/issues/3441
data_transformers.register("csv", to_csv) # type: ignore[arg-type]
data_transformers.register("vegafusion", vegafusion_data_transformer)
data_transformers.enable("default")

Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable


class TypedCallableRegistry(PluginRegistry[Callable[[int], int]]):
class TypedCallableRegistry(PluginRegistry[Callable[[int], int], int]):
pass


Expand Down

0 comments on commit b1b8c6e

Please sign in to comment.