diff --git a/altair/utils/core.py b/altair/utils/core.py
index 61e370b1d..c429a898f 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -1,7 +1,7 @@
 """
 Utility routines
 """
-from collections.abc import Mapping
+from collections.abc import Mapping, MutableMapping
 from copy import deepcopy
 import json
 import itertools
@@ -9,7 +9,8 @@
 import sys
 import traceback
 import warnings
-from typing import Callable, TypeVar, Any
+from typing import Callable, TypeVar, Any, Union, Dict, Optional, Tuple, Sequence, Type
+from types import ModuleType
 
 import jsonschema
 import pandas as pd
@@ -23,9 +24,9 @@
 from typing_extensions import ParamSpec
 
 if sys.version_info >= (3, 8):
-    from typing import Protocol
+    from typing import Literal, Protocol
 else:
-    from typing_extensions import Protocol
+    from typing_extensions import Literal, Protocol
 
 try:
     from pandas.api.types import infer_dtype as _infer_dtype
@@ -42,7 +43,7 @@ def __dataframe__(self, *args, **kwargs):
         ...
 
 
-def infer_dtype(value):
+def infer_dtype(value: object) -> str:
     """Infer the dtype of the value.
 
     This is a compatibility function for pandas infer_dtype,
@@ -53,10 +54,10 @@ def infer_dtype(value):
            _infer_dtype([1], skipna=False)
        except TypeError:
            # pandas < 0.21.0 doesn't support the skipna keyword
-            infer_dtype._supports_skipna = False
+            infer_dtype._supports_skipna = False  # type: ignore[attr-defined]
        else:
-            infer_dtype._supports_skipna = True
-    if infer_dtype._supports_skipna:
+            infer_dtype._supports_skipna = True  # type: ignore[attr-defined]
+    if infer_dtype._supports_skipna:  # type: ignore[attr-defined]
        return _infer_dtype(value, skipna=False)
    else:
        return _infer_dtype(value)
@@ -200,16 +201,20 @@
 ]
 
 
-def infer_vegalite_type(data):
+_InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
+
+
+def infer_vegalite_type(
+    data: object,
+) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
     """
     From an array-like input, infer the correct vega typecode
     ('ordinal', 'nominal', 'quantitative', or 'temporal')
 
     Parameters
     ----------
-    data: Numpy array or Pandas Series
+    data: object
     """
-    # Otherwise, infer based on the dtype of the input
     typ = infer_dtype(data)
 
     if typ in [
@@ -220,7 +225,7 @@
         "complex",
     ]:
         return "quantitative"
-    elif typ == "categorical" and data.cat.ordered:
+    elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
         return ("ordinal", data.cat.categories.tolist())
     elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
         return "nominal"
@@ -243,7 +248,7 @@
         return "nominal"
 
 
-def merge_props_geom(feat):
+def merge_props_geom(feat: dict) -> dict:
     """
     Merge properties with geometry
     * Overwrites 'type' and 'geometry' entries if existing
@@ -261,7 +266,7 @@
     return props_geom
 
 
-def sanitize_geo_interface(geo):
+def sanitize_geo_interface(geo: MutableMapping) -> dict:
     """Sanitize a geo_interface to prepare it for serialization.
 
     * Make a copy
@@ -278,23 +283,23 @@
             geo[key] = geo[key].tolist()
 
     # convert (nested) tuples to lists
-    geo = json.loads(json.dumps(geo))
+    geo_dct: dict = json.loads(json.dumps(geo))
 
     # sanitize features
-    if geo["type"] == "FeatureCollection":
-        geo = geo["features"]
-        if len(geo) > 0:
-            for idx, feat in enumerate(geo):
-                geo[idx] = merge_props_geom(feat)
-    elif geo["type"] == "Feature":
-        geo = merge_props_geom(geo)
+    if geo_dct["type"] == "FeatureCollection":
+        geo_dct = geo_dct["features"]
+        if len(geo_dct) > 0:
+            for idx, feat in enumerate(geo_dct):
+                geo_dct[idx] = merge_props_geom(feat)
+    elif geo_dct["type"] == "Feature":
+        geo_dct = merge_props_geom(geo_dct)
     else:
-        geo = {"type": "Feature", "geometry": geo}
+        geo_dct = {"type": "Feature", "geometry": geo_dct}
 
-    return geo
+    return geo_dct
 
 
-def sanitize_dataframe(df):  # noqa: C901
+def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:  # noqa: C901
     """Sanitize a DataFrame to prepare it for serialization.
 
     * Make a copy
@@ -433,13 +438,13 @@ def sanitize_arrow_table(pa_table):
 
 
 def parse_shorthand(
-    shorthand,
-    data=None,
-    parse_aggregates=True,
-    parse_window_ops=False,
-    parse_timeunits=True,
-    parse_types=True,
-):
+    shorthand: Union[Dict[str, Any], str],
+    data: Optional[pd.DataFrame] = None,
+    parse_aggregates: bool = True,
+    parse_window_ops: bool = False,
+    parse_timeunits: bool = True,
+    parse_types: bool = True,
+) -> Dict[str, Any]:
     """General tool to parse shorthand values
 
     These are of the form:
@@ -554,7 +559,9 @@
         attrs = shorthand
     else:
         attrs = next(
-            exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
+            exp.match(shorthand).groupdict()  # type: ignore[union-attr]
+            for exp in regexps
+            if exp.match(shorthand) is not None
         )
 
     # Handle short form of the type expression
@@ -629,21 +636,23 @@
     return decorate
 
 
-def update_nested(original, update, copy=False):
+def update_nested(
+    original: MutableMapping, update: Mapping, copy: bool = False
+) -> MutableMapping:
     """Update nested dictionaries
 
     Parameters
     ----------
-    original : dict
+    original : MutableMapping
         the original (nested) dictionary, which will be updated in-place
-    update : dict
+    update : Mapping
         the nested dictionary of updates
     copy : bool, default False
         if True, then copy the original dictionary rather than modifying it
 
     Returns
     -------
-    original : dict
+    original : MutableMapping
         a reference to the (modified) original dict
 
     Examples
@@ -660,7 +669,7 @@
     for key, val in update.items():
         if isinstance(val, Mapping):
             orig_val = original.get(key, {})
-            if isinstance(orig_val, Mapping):
+            if isinstance(orig_val, MutableMapping):
                 original[key] = update_nested(orig_val, val)
             else:
                 original[key] = val
@@ -669,7 +678,7 @@
     return original
 
 
-def display_traceback(in_ipython=True):
+def display_traceback(in_ipython: bool = True):
     exc_info = sys.exc_info()
 
     if in_ipython:
@@ -685,16 +694,16 @@
         traceback.print_exception(*exc_info)
 
 
-def infer_encoding_types(args, kwargs, channels):
+def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType):
     """Infer typed keyword arguments for args and kwargs
 
     Parameters
     ----------
-    args : tuple
-        List of function args
-    kwargs : dict
+    args : Sequence
+        Sequence of function args
+    kwargs : MutableMapping
         Dict of function kwargs
-    channels : module
+    channels : ModuleType
         The module containing all altair encoding channel classes.
 
     Returns
@@ -709,8 +718,10 @@
     channel_objs = (
         c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
     )
-    channel_to_name = {c: c._encoding_name for c in channel_objs}
-    name_to_channel = {}
+    channel_to_name: Dict[Type[SchemaBase], str] = {
+        c: c._encoding_name for c in channel_objs
+    }
+    name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {}
     for chan, name in channel_to_name.items():
         chans = name_to_channel.setdefault(name, {})
         if chan.__name__.endswith("Datum"):
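
Example (illustrative, not part of the patch): the newly annotated parse_shorthand accepts either an encoding dict or a shorthand string and returns a plain dict, with the type inferred from the data when a DataFrame is supplied. The output values below are my expectation, not verified against this exact revision:

import pandas as pd
from altair.utils.core import parse_shorthand

parse_shorthand("mean(price):Q")
# expected: {'aggregate': 'mean', 'field': 'price', 'type': 'quantitative'}

df = pd.DataFrame({"price": [1.0, 2.5, 4.0]})
parse_shorthand("price", df)
# expected: {'field': 'price', 'type': 'quantitative'}, the type routed through infer_vegalite_type
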
diff --git a/altair/utils/data.py b/altair/utils/data.py
index c72fd2ea9..9355dbef7 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -2,55 +2,71 @@
 import os
 import random
 import hashlib
+import sys
 import warnings
+from typing import Union, MutableMapping, Optional, Dict, Sequence, TYPE_CHECKING, List
+from types import ModuleType
 
 import pandas as pd
 from toolz import curried
-from typing import Callable
+from typing import TypeVar
 
-from .core import sanitize_dataframe, sanitize_arrow_table
+from .core import sanitize_dataframe, sanitize_arrow_table, _DataFrameLike
 from .core import sanitize_geo_interface
 from .deprecation import AltairDeprecationWarning
 from .plugin_registry import PluginRegistry
 
+if sys.version_info >= (3, 8):
+    from typing import Protocol, TypedDict, Literal
+else:
+    from typing_extensions import Protocol, TypedDict, Literal
+
+
+if TYPE_CHECKING:
+    import pyarrow.lib
+
+
+class _SupportsGeoInterface(Protocol):
+    __geo_interface__: MutableMapping
+
+
+_DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _DataFrameLike]
+_TDataType = TypeVar("_TDataType", bound=_DataType)
+
+_VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
+_ToValuesReturnType = Dict[str, Union[dict, List[dict]]]
+
+
 # ==============================================================================
 # Data transformer registry
+#
+# A data transformer is a callable that takes a supported data type and returns
+# a transformed dictionary version of it which is compatible with the VegaLite schema.
+# The dict objects will be the Data portion of the VegaLite schema.
+#
+# Renderers only deal with the dict form of a
+# VegaLite spec, after the Data model has been put into a schema compliant
+# form.
 # ==============================================================================
-DataTransformerType = Callable
+class DataTransformerType(Protocol):
+    def __call__(self, data: _DataType, **kwargs) -> _VegaLiteDataDict:
+        pass
 
 
 class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
     _global_settings = {"consolidate_datasets": True}
 
     @property
-    def consolidate_datasets(self):
+    def consolidate_datasets(self) -> bool:
         return self._global_settings["consolidate_datasets"]
 
     @consolidate_datasets.setter
-    def consolidate_datasets(self, value):
+    def consolidate_datasets(self, value: bool) -> None:
         self._global_settings["consolidate_datasets"] = value
 
 
 # ==============================================================================
-# Data model transformers
-#
-# A data model transformer is a pure function that takes a dict or DataFrame
-# and returns a transformed version of a dict or DataFrame. The dict objects
-# will be the Data portion of the VegaLite schema. The idea is that user can
-# pipe a sequence of these data transformers together to prepare the data before
-# it hits the renderer.
-#
-# In this version of Altair, renderers only deal with the dict form of a
-# VegaLite spec, after the Data model has been put into a schema compliant
-# form.
-#
-# A data model transformer has the following type signature:
-# DataModelType = Union[dict, pd.DataFrame]
-# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType]
-# ==============================================================================
-
-
 class MaxRowsError(Exception):
     """Raised when a data model has too many rows."""
@@ -58,7 +74,7 @@
 
 
 @curried.curry
-def limit_rows(data, max_rows=5000):
+def limit_rows(data: _TDataType, max_rows: Optional[int] = 5000) -> _TDataType:
     """Raise MaxRowsError if the data model has more than max_rows.
 
     If max_rows is None, then do not perform any check.
@@ -75,7 +91,9 @@
         if "values" in data:
             values = data["values"]
         else:
-            return data
+            # mypy gets confused as it doesn't see Dict[Any, Any]
+            # as equivalent to TDataType
+            return data  # type: ignore[return-value]
     elif hasattr(data, "__dataframe__"):
         values = data
     if max_rows is not None and len(values) > max_rows:
@@ -91,7 +109,9 @@
 
 
 @curried.curry
-def sample(data, n=None, frac=None):
+def sample(
+    data: _DataType, n: Optional[int] = None, frac: Optional[float] = None
+) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]:
     """Reduce the size of the data model by sampling without replacement."""
     check_data_type(data)
     if isinstance(data, pd.DataFrame):
@@ -99,26 +119,61 @@
     elif isinstance(data, dict):
         if "values" in data:
             values = data["values"]
-            n = n if n else int(frac * len(values))
+            if not n:
+                if frac is None:
+                    raise ValueError(
+                        "frac cannot be None if n is None and data is a dictionary"
+                    )
+                n = int(frac * len(values))
             values = random.sample(values, n)
             return {"values": values}
+        else:
+            # Maybe this should raise an error or return something useful?
+            return None
     elif hasattr(data, "__dataframe__"):
         # experimental interchange dataframe support
         pi = import_pyarrow_interchange()
         pa_table = pi.from_dataframe(data)
-        n = n if n else int(frac * len(pa_table))
+        if not n:
+            if frac is None:
+                raise ValueError(
+                    "frac cannot be None if n is None with this data input type"
+                )
+            n = int(frac * len(pa_table))
         indices = random.sample(range(len(pa_table)), n)
         return pa_table.take(indices)
+    else:
+        # Maybe this should raise an error or return something useful? Currently,
+        # if data is of type SupportsGeoInterface it lands here
+        return None
+
+
+class _JsonFormatDict(TypedDict):
+    type: Literal["json"]
+
+
+class _CsvFormatDict(TypedDict):
+    type: Literal["csv"]
+
+
+class _ToJsonReturnUrlDict(TypedDict):
+    url: str
+    format: _JsonFormatDict
+
+
+class _ToCsvReturnUrlDict(TypedDict):
+    url: str
+    format: _CsvFormatDict
 
 
 @curried.curry
 def to_json(
-    data,
-    prefix="altair-data",
-    extension="json",
-    filename="{prefix}-{hash}.{extension}",
-    urlpath="",
-):
+    data: _DataType,
+    prefix: str = "altair-data",
+    extension: str = "json",
+    filename: str = "{prefix}-{hash}.{extension}",
+    urlpath: str = "",
+) -> _ToJsonReturnUrlDict:
     """
     Write the data model to a .json file and return a url based data model.
     """
@@ -132,12 +187,12 @@
 
 @curried.curry
 def to_csv(
-    data,
-    prefix="altair-data",
-    extension="csv",
-    filename="{prefix}-{hash}.{extension}",
-    urlpath="",
-):
+    data: Union[dict, pd.DataFrame, _DataFrameLike],
+    prefix: str = "altair-data",
+    extension: str = "csv",
+    filename: str = "{prefix}-{hash}.{extension}",
+    urlpath: str = "",
+) -> _ToCsvReturnUrlDict:
     """Write the data model to a .csv file and return a url based data model."""
     data_csv = _data_to_csv_string(data)
     data_hash = _compute_data_hash(data_csv)
@@ -148,14 +203,16 @@
 
 
 @curried.curry
-def to_values(data):
+def to_values(data: _DataType) -> _ToValuesReturnType:
     """Replace a DataFrame by a data model with values."""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
         if isinstance(data, pd.DataFrame):
             data = sanitize_dataframe(data)
-        data = sanitize_geo_interface(data.__geo_interface__)
-        return {"values": data}
+        # Maybe the type could be further clarified here that it is
+        # SupportsGeoInterface and then the ignore statement is not needed?
+        data_sanitized = sanitize_geo_interface(data.__geo_interface__)  # type: ignore[arg-type]
+        return {"values": data_sanitized}
     elif isinstance(data, pd.DataFrame):
         data = sanitize_dataframe(data)
         return {"values": data.to_dict(orient="records")}
@@ -168,10 +225,12 @@
         pi = import_pyarrow_interchange()
         pa_table = sanitize_arrow_table(pi.from_dataframe(data))
         return {"values": pa_table.to_pylist()}
+    else:
+        # Should never reach this state as tested by check_data_type
+        raise ValueError("Unrecognized data type: {}".format(type(data)))
 
 
-def check_data_type(data):
-    """Raise if the data is not a dict or DataFrame."""
+def check_data_type(data: _DataType) -> None:
     if not isinstance(data, (dict, pd.DataFrame)) and not any(
         hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
     ):
@@ -185,17 +244,19 @@
 # ==============================================================================
 # Private utilities
 # ==============================================================================
-def _compute_data_hash(data_str):
+def _compute_data_hash(data_str: str) -> str:
     return hashlib.md5(data_str.encode()).hexdigest()
 
 
-def _data_to_json_string(data):
+def _data_to_json_string(data: _DataType) -> str:
     """Return a JSON string representation of the input data"""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
         if isinstance(data, pd.DataFrame):
             data = sanitize_dataframe(data)
-        data = sanitize_geo_interface(data.__geo_interface__)
+        # Maybe the type could be further clarified here that it is
+        # SupportsGeoInterface and then the ignore statement is not needed?
+        data = sanitize_geo_interface(data.__geo_interface__)  # type: ignore[arg-type]
         return json.dumps(data)
     elif isinstance(data, pd.DataFrame):
         data = sanitize_dataframe(data)
@@ -215,7 +276,7 @@
     )
 
 
-def _data_to_csv_string(data):
+def _data_to_csv_string(data: Union[dict, pd.DataFrame, _DataFrameLike]) -> str:
     """Return a CSV string representation of the input data"""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
@@ -275,7 +336,7 @@
     return curried.curry(*args, **kwargs)
 
 
-def import_pyarrow_interchange():
+def import_pyarrow_interchange() -> ModuleType:
     import pkg_resources
 
     try:
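
Example (illustrative, not part of the patch): because DataTransformerType is now a Protocol rather than a bare Callable, any callable that takes a supported data type and returns a Vega-Lite data dict conforms, including user-defined transformers registered through the public registry. A minimal sketch; head_transformer and the "head" name are hypothetical:

import altair as alt
import pandas as pd
from altair.utils.data import limit_rows, to_values

def head_transformer(data, n: int = 100, **kwargs):
    # Trim DataFrames before applying the default limit/serialize steps.
    if isinstance(data, pd.DataFrame):
        data = data.head(n)
    return to_values(limit_rows(data, max_rows=5000))

alt.data_transformers.register("head", head_transformer)
alt.data_transformers.enable("head")
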
""" - options = { + options: Dict[str, Optional[Union[bool, str, float, Dict[str, bool]]]] = { "defaultStyle": defaultStyle, "renderer": renderer, "width": width, @@ -115,40 +126,44 @@ class Displayable: renderers: Optional[RendererRegistry] = None schema_path = ("altair", "") - def __init__(self, spec, validate=False): - # type: (dict, bool) -> None + def __init__(self, spec: dict, validate: bool = False) -> None: self.spec = spec self.validate = validate self._validate() - def _validate(self): - # type: () -> None + def _validate(self) -> None: """Validate the spec against the schema.""" data = pkgutil.get_data(*self.schema_path) assert data is not None - schema_dict = json.loads(data.decode("utf-8")) + schema_dict: dict = json.loads(data.decode("utf-8")) validate_jsonschema( self.spec, schema_dict, ) - def _repr_mimebundle_(self, include=None, exclude=None): + def _repr_mimebundle_( + self, include: Any = None, exclude: Any = None + ) -> MimeBundleType: """Return a MIME bundle for display in Jupyter frontends.""" if self.renderers is not None: - return self.renderers.get()(self.spec) + renderer_func = self.renderers.get() + assert renderer_func is not None + return renderer_func(self.spec) else: return {} -def default_renderer_base(spec, mime_type, str_repr, **options): +def default_renderer_base( + spec: dict, mime_type: str, str_repr: str, **options +) -> DefaultRendererReturnType: """A default renderer for Vega or VegaLite that works for modern frontends. This renderer works with modern frontends (JupyterLab, nteract) that know how to render the custom VegaLite MIME type listed above. """ assert isinstance(spec, dict) - bundle = {} - metadata = {} + bundle: Dict[str, Union[str, dict]] = {} + metadata: Dict[str, Dict[str, Any]] = {} bundle[mime_type] = spec bundle["text/plain"] = str_repr @@ -157,7 +172,9 @@ def default_renderer_base(spec, mime_type, str_repr, **options): return bundle, metadata -def json_renderer_base(spec, str_repr, **options): +def json_renderer_base( + spec: dict, str_repr: str, **options +) -> DefaultRendererReturnType: """A renderer that returns a MIME type of application/json. In JupyterLab/nteract this is rendered as a nice JSON tree. 
@@ -170,15 +187,15 @@ def json_renderer_base(spec, str_repr, **options): class HTMLRenderer: """Object to render charts as HTML, with a unique output div each time""" - def __init__(self, output_div="altair-viz-{}", **kwargs): + def __init__(self, output_div: str = "altair-viz-{}", **kwargs) -> None: self._output_div = output_div self.kwargs = kwargs @property - def output_div(self): + def output_div(self) -> str: return self._output_div.format(uuid.uuid4().hex) - def __call__(self, spec, **metadata): + def __call__(self, spec: dict, **metadata) -> Dict[str, str]: kwargs = self.kwargs.copy() kwargs.update(metadata) return spec_to_mimebundle( diff --git a/altair/utils/html.py b/altair/utils/html.py index a7a29fc74..c1084aeec 100644 --- a/altair/utils/html.py +++ b/altair/utils/html.py @@ -1,4 +1,6 @@ import json +from typing import Optional, Dict + import jinja2 @@ -201,7 +203,7 @@ ) -TEMPLATES = { +TEMPLATES: Dict[str, jinja2.Template] = { "standard": HTML_TEMPLATE, "universal": HTML_TEMPLATE_UNIVERSAL, "inline": INLINE_HTML_TEMPLATE, @@ -209,19 +211,19 @@ def spec_to_html( - spec, - mode, - vega_version, - vegaembed_version, - vegalite_version=None, - base_url="https://cdn.jsdelivr.net/npm", - output_div="vis", - embed_options=None, - json_kwds=None, - fullhtml=True, - requirejs=False, - template="standard", -): + spec: dict, + mode: str, + vega_version: str, + vegaembed_version: str, + vegalite_version: Optional[str] = None, + base_url: str = "https://cdn.jsdelivr.net/npm", + output_div: str = "vis", + embed_options: Optional[dict] = None, + json_kwds: Optional[dict] = None, + fullhtml: bool = True, + requirejs: bool = False, + template: str = "standard", +) -> str: """Embed a Vega/Vega-Lite spec into an HTML page Parameters @@ -295,11 +297,11 @@ def spec_to_html( "vega-embed", vegaembed_version ) - template = TEMPLATES.get(template, template) - if not hasattr(template, "render"): - raise ValueError("Invalid template: {0}".format(template)) + jinja_template = TEMPLATES.get(template, template) + if not hasattr(jinja_template, "render"): + raise ValueError("Invalid template: {0}".format(jinja_template)) - return template.render( + return jinja_template.render( spec=json.dumps(spec, **json_kwds), embed_options=json.dumps(embed_options), mode=mode, diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index 84351b9b0..75568591d 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -461,7 +461,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: """Format param names into a table so that they are easier to read""" param_names: Tuple[str, ...] name_lengths: Tuple[int, ...] 
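
Example (illustrative, not part of the patch): set_embed_options now returns the PluginEnabler produced by the underlying registry rather than None, so the call can also be used as a context manager, assuming PluginEnabler's existing __enter__/__exit__ support:

import altair as alt

# Persistent: hide the vega-embed actions menu for all rendered charts.
alt.renderers.set_embed_options(actions=False)

# Scoped: options apply only inside the block, then are restored.
with alt.renderers.set_embed_options(renderer="svg", scaleFactor=2):
    ...  # display charts here
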
diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py
index 30289160b..3aca8ea5b 100644
--- a/altair/vegalite/data.py
+++ b/altair/vegalite/data.py
@@ -12,15 +12,19 @@
     check_data_type,
 )
 from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry
+from ..utils.data import _DataType, _ToValuesReturnType
+from ..utils.plugin_registry import PluginEnabler
 
 
 @curried.curry
-def default_data_transformer(data, max_rows=5000):
+def default_data_transformer(
+    data: _DataType, max_rows: int = 5000
+) -> _ToValuesReturnType:
     return curried.pipe(data, limit_rows(max_rows=max_rows), to_values)
 
 
 class DataTransformerRegistry(_DataTransformerRegistry):
-    def disable_max_rows(self):
+    def disable_max_rows(self) -> PluginEnabler:
         """Disable the MaxRowsError."""
         options = self.options
         if self.active == "default":
diff --git a/altair/vegalite/display.py b/altair/vegalite/display.py
index 91c5f33e0..1f3a13b46 100644
--- a/altair/vegalite/display.py
+++ b/altair/vegalite/display.py
@@ -1,4 +1,9 @@
-from ..utils.display import Displayable, default_renderer_base, json_renderer_base
+from ..utils.display import (
+    Displayable,
+    default_renderer_base,
+    json_renderer_base,
+    DefaultRendererReturnType,
+)
 from ..utils.display import RendererRegistry, HTMLRenderer
 
 
@@ -8,4 +13,5 @@
     "json_renderer_base",
     "RendererRegistry",
     "HTMLRenderer",
+    "DefaultRendererReturnType",
 )
diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index 218bf0e79..3d61bc271 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -2369,7 +2369,7 @@ def show(self, embed_opt=None, open_browser=None):
             connected to a browser.
         """
         try:
-            import altair_viewer  # type: ignore
+            import altair_viewer
         except ImportError as err:
             raise ValueError(
                 "'show' method requires the altair_viewer package. "
@@ -2468,7 +2468,7 @@
             )
 
         # Remove "ignore" statement once Undefined is no longer typed as Any
-        if data is Undefined:  # type: ignore
+        if data is Undefined:
             # Remove "ignore" statement once Undefined is no longer typed as Any
             if self.data is Undefined:  # type: ignore
                 raise ValueError(
diff --git a/altair/vegalite/v5/compiler.py b/altair/vegalite/v5/compiler.py
index 60afa99f0..a4e02f79f 100644
--- a/altair/vegalite/v5/compiler.py
+++ b/altair/vegalite/v5/compiler.py
@@ -1,6 +1,14 @@
+import sys
+
 from ...utils.compiler import VegaLiteCompilerRegistry
 
-ENTRY_POINT_GROUP: str = "altair.vegalite.v5.vegalite_compiler"
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
+
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.vegalite_compiler"
 
 vegalite_compilers = VegaLiteCompilerRegistry(entry_point_group=ENTRY_POINT_GROUP)
diff --git a/altair/vegalite/v5/data.py b/altair/vegalite/v5/data.py
index 703dffb32..af7af4d04 100644
--- a/altair/vegalite/v5/data.py
+++ b/altair/vegalite/v5/data.py
@@ -1,3 +1,5 @@
+import sys
+
 from ..data import (
     MaxRowsError,
     curry,
@@ -11,18 +13,21 @@
     DataTransformerRegistry,
 )
 
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
 
 # ==============================================================================
 # VegaLite 5 data transformers
 # ==============================================================================
-ENTRY_POINT_GROUP = "altair.vegalite.v5.data_transformer"  # type: str
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.data_transformer"
 
 
-data_transformers = DataTransformerRegistry(
-    entry_point_group=ENTRY_POINT_GROUP
-)  # type: DataTransformerRegistry
+data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP)
 
 data_transformers.register("default", default_data_transformer)
 data_transformers.register("json", to_json)
 data_transformers.register("csv", to_csv)
diff --git a/altair/vegalite/v5/display.py b/altair/vegalite/v5/display.py
index 69fa41b10..273d8f3f4 100644
--- a/altair/vegalite/v5/display.py
+++ b/altair/vegalite/v5/display.py
@@ -1,17 +1,27 @@
 import os
+import sys
+from typing import Dict
 
 from ...utils.mimebundle import spec_to_mimebundle
-from ..display import Displayable
-from ..display import default_renderer_base
-from ..display import json_renderer_base
-from ..display import RendererRegistry
-from ..display import HTMLRenderer
+from ..display import (
+    Displayable,
+    default_renderer_base,
+    json_renderer_base,
+    RendererRegistry,
+    HTMLRenderer,
+    DefaultRendererReturnType,
+)
 
 from .schema import SCHEMA_VERSION
 
-VEGALITE_VERSION = SCHEMA_VERSION.lstrip("v")
-VEGA_VERSION = "5"
-VEGAEMBED_VERSION = "6"
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
+VEGALITE_VERSION: Final = SCHEMA_VERSION.lstrip("v")
+VEGA_VERSION: Final = "5"
+VEGAEMBED_VERSION: Final = "6"
 
 
 # ==============================================================================
@@ -20,15 +30,15 @@
 
 
 # The MIME type for Vega-Lite 5.x releases.
-VEGALITE_MIME_TYPE = "application/vnd.vegalite.v5+json"  # type: str
+VEGALITE_MIME_TYPE: Final = "application/vnd.vegalite.v5+json"
 
 # The entry point group that can be used by other packages to declare other
 # renderers that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.renderer"  # type: str
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.renderer"
 
 # The display message when rendering fails
-DEFAULT_DISPLAY = f"""\
+DEFAULT_DISPLAY: Final = f"""\
 <VegaLite {VEGALITE_VERSION.split('.')[0]} object>
 
 If you see this message, it means the renderer has not been properly enabled
@@ -41,15 +51,15 @@
 here = os.path.dirname(os.path.realpath(__file__))
 
 
-def mimetype_renderer(spec, **metadata):
+def mimetype_renderer(spec: dict, **metadata) -> DefaultRendererReturnType:
     return default_renderer_base(spec, VEGALITE_MIME_TYPE, DEFAULT_DISPLAY, **metadata)
 
 
-def json_renderer(spec, **metadata):
+def json_renderer(spec: dict, **metadata) -> DefaultRendererReturnType:
     return json_renderer_base(spec, DEFAULT_DISPLAY, **metadata)
 
 
-def png_renderer(spec, **metadata):
+def png_renderer(spec: dict, **metadata) -> Dict[str, bytes]:
     return spec_to_mimebundle(
         spec,
         format="png",
@@ -61,7 +71,7 @@
     )
 
 
-def svg_renderer(spec, **metadata):
+def svg_renderer(spec: dict, **metadata) -> Dict[str, str]:
     return spec_to_mimebundle(
         spec,
         format="svg",
@@ -102,7 +112,7 @@ class VegaLite(Displayable):
     schema_path = (__name__, "schema/vega-lite-schema.json")
 
 
-def vegalite(spec, validate=True):
+def vegalite(spec: dict, validate: bool = True) -> None:
     """Render and optionally validate a VegaLite 5 spec.
 
     This will use the currently enabled renderer to render the spec.
diff --git a/altair/vegalite/v5/theme.py b/altair/vegalite/v5/theme.py
index b536a1dde..cdaca2dbf 100644
--- a/altair/vegalite/v5/theme.py
+++ b/altair/vegalite/v5/theme.py
@@ -1,4 +1,6 @@
 """Tools for enabling and registering chart themes"""
+import sys
+from typing import Dict, Union
 
 from ...utils.theme import ThemeRegistry
@@ -15,27 +17,32 @@
     "powerbi",
 ]
 
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
 
 class VegaTheme:
     """Implementation of a builtin vega theme."""
 
-    def __init__(self, theme):
+    def __init__(self, theme: str) -> None:
         self.theme = theme
 
-    def __call__(self):
+    def __call__(self) -> Dict[str, Dict[str, Dict[str, Union[str, int]]]]:
         return {
             "usermeta": {"embedOptions": {"theme": self.theme}},
             "config": {"view": {"continuousWidth": 300, "continuousHeight": 300}},
         }
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "VegaTheme({!r})".format(self.theme)
 
 
 # The entry point group that can be used by other packages to declare other
-# renderers that will be auto-detected. Explicit registration is also
+# themes that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.theme"  # type: str
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.theme"
 
 themes = ThemeRegistry(entry_point_group=ENTRY_POINT_GROUP)
 themes.register(
diff --git a/pyproject.toml b/pyproject.toml
index dbd489bbf..0c165cce1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -220,6 +220,9 @@ markers = [
     "save_engine: marks some of the tests which are using an external package to save a chart to e.g. a png file. This mark is used to run those tests selectively in the build GitHub Action.",
 ]
 
+[tool.mypy]
+warn_unused_ignores = true
+
 [[tool.mypy.overrides]]
 module = [
     "vega_datasets.*",
@@ -229,7 +232,9 @@ module = [
     "pyarrow.*",
     "yaml.*",
     "vl_convert.*",
-    "pandas.lib.*"
+    "pandas.lib.*",
+    "nbformat.*",
+    "ipykernel.*"
 ]
 ignore_missing_imports = true
diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py
index 7809acafc..5a43f574d 100644
--- a/tools/schemapi/schemapi.py
+++ b/tools/schemapi/schemapi.py
@@ -459,7 +459,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
     """Format param names into a table so that they are easier to read"""
     param_names: Tuple[str, ...]
     name_lengths: Tuple[int, ...]
-    param_names, name_lengths = zip(  # type: ignore[assignment] # Mypy does think it's Tuple[Any]
+    param_names, name_lengths = zip(
         *[
             (name, len(name))
             for name in param_dict_keys
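
Example (illustrative, not part of the patch): the VegaTheme.__call__ annotation matches the config dict the builtin themes emit, and themes are still selected through the registry:

import altair as alt
from altair.vegalite.v5.theme import VegaTheme

VegaTheme("dark")()
# {'usermeta': {'embedOptions': {'theme': 'dark'}},
#  'config': {'view': {'continuousWidth': 300, 'continuousHeight': 300}}}

alt.themes.enable("dark")  # applies to subsequently rendered charts
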