From 4146bba551743b38931c647fe063d068ad7b7be1 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sat, 18 Mar 2023 15:28:36 +0000
Subject: [PATCH 01/17] Type hint vegalite/v5/theme.py

---
 altair/vegalite/v5/theme.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/altair/vegalite/v5/theme.py b/altair/vegalite/v5/theme.py
index b536a1dde..5bd6dcecf 100644
--- a/altair/vegalite/v5/theme.py
+++ b/altair/vegalite/v5/theme.py
@@ -1,4 +1,5 @@
 """Tools for enabling and registering chart themes"""
+from typing import Dict, Union
 
 from ...utils.theme import ThemeRegistry
 
@@ -19,23 +20,23 @@
 class VegaTheme:
     """Implementation of a builtin vega theme."""
 
-    def __init__(self, theme):
+    def __init__(self, theme: str) -> None:
         self.theme = theme
 
-    def __call__(self):
+    def __call__(self) -> Dict[str, Dict[str, Dict[str, Union[str, int]]]]:
         return {
             "usermeta": {"embedOptions": {"theme": self.theme}},
             "config": {"view": {"continuousWidth": 300, "continuousHeight": 300}},
         }
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "VegaTheme({!r})".format(self.theme)
 
 
 # The entry point group that can be used by other packages to declare other
 # renderers that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.theme"  # type: str
+ENTRY_POINT_GROUP = "altair.vegalite.v5.theme"
 themes = ThemeRegistry(entry_point_group=ENTRY_POINT_GROUP)
 
 themes.register(

From cc17fa2842284360e429096d92e6b3f482e4a4b7 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sat, 18 Mar 2023 18:53:10 +0000
Subject: [PATCH 02/17] Type hint display modules

---
 altair/utils/display.py       | 69 ++++++++++++++++++++++-------------
 altair/vegalite/display.py    |  8 +++-
 altair/vegalite/v5/display.py | 28 ++++++++------
 altair/vegalite/v5/theme.py   |  2 +-
 4 files changed, 67 insertions(+), 40 deletions(-)

diff --git a/altair/utils/display.py b/altair/utils/display.py
index 730ca6534..f92b64d3f 100644
--- a/altair/utils/display.py
+++ b/altair/utils/display.py
@@ -1,10 +1,10 @@
 import json
 import pkgutil
 import textwrap
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Tuple, Any, Union
 import uuid
 
-from .plugin_registry import PluginRegistry
+from .plugin_registry import PluginRegistry, PluginEnabler
 from .mimebundle import spec_to_mimebundle
 from .schemapi import validate_jsonschema
 
@@ -12,8 +12,19 @@
 # ==============================================================================
 # Renderer registry
 # ==============================================================================
-MimeBundleType = Dict[str, object]
+# MimeBundleType needs to match the return values that are acceptable
+# for _repr_mimebundle_,
+# see https://ipython.readthedocs.io/en/stable/config/integrating.html#MyObject._repr_mimebundle_
+MimeBundleDataType = Dict[str, Any]
+MimeBundleMetaDataType = Dict[str, Any]
+MimeBundleType = Union[
+    MimeBundleDataType, Tuple[MimeBundleDataType, MimeBundleMetaDataType]
+]
 RendererType = Callable[..., MimeBundleType]
+# A subtype of MimeBundleType that is more specific about the dictionary values
+DefaultRendererReturnType = Tuple[
+    Dict[str, Union[str, dict]], Dict[str, Dict[str, Any]]
+]
 
 
 class RendererRegistry(PluginRegistry[RendererType]):
@@ -37,15 +48,15 @@ class RendererRegistry(PluginRegistry[RendererType]):
 
     def set_embed_options(
         self,
-        defaultStyle=None,
-        renderer=None,
-        width=None,
-        height=None,
-        padding=None,
-        scaleFactor=None,
-        actions=None,
+        defaultStyle: 
Optional[Union[bool, str]] = None, + renderer: Optional[str] = None, + width: Optional[int] = None, + height: Optional[int] = None, + padding: Optional[int] = None, + scaleFactor: Optional[float] = None, + actions: Optional[Union[bool, Dict[str, bool]]] = None, **kwargs, - ): + ) -> PluginEnabler: """Set options for embeddings of Vega & Vega-Lite charts. Options are fully documented at https://github.com/vega/vega-embed. @@ -79,7 +90,7 @@ def set_embed_options( **kwargs : Additional options are passed directly to embed options. """ - options = { + options: Dict[str, Optional[Union[bool, str, float, Dict[str, bool]]]] = { "defaultStyle": defaultStyle, "renderer": renderer, "width": width, @@ -115,40 +126,44 @@ class Displayable: renderers: Optional[RendererRegistry] = None schema_path = ("altair", "") - def __init__(self, spec, validate=False): - # type: (dict, bool) -> None + def __init__(self, spec: dict, validate: bool = False) -> None: self.spec = spec self.validate = validate self._validate() - def _validate(self): - # type: () -> None + def _validate(self) -> None: """Validate the spec against the schema.""" data = pkgutil.get_data(*self.schema_path) assert data is not None - schema_dict = json.loads(data.decode("utf-8")) + schema_dict: dict = json.loads(data.decode("utf-8")) validate_jsonschema( self.spec, schema_dict, ) - def _repr_mimebundle_(self, include=None, exclude=None): + def _repr_mimebundle_( + self, include: Any = None, exclude: Any = None + ) -> MimeBundleType: """Return a MIME bundle for display in Jupyter frontends.""" if self.renderers is not None: - return self.renderers.get()(self.spec) + renderer_func = self.renderers.get() + assert renderer_func is not None + return renderer_func(self.spec) else: return {} -def default_renderer_base(spec, mime_type, str_repr, **options): +def default_renderer_base( + spec: dict, mime_type: str, str_repr: str, **options +) -> DefaultRendererReturnType: """A default renderer for Vega or VegaLite that works for modern frontends. This renderer works with modern frontends (JupyterLab, nteract) that know how to render the custom VegaLite MIME type listed above. """ assert isinstance(spec, dict) - bundle = {} - metadata = {} + bundle: Dict[str, Union[str, dict]] = {} + metadata: Dict[str, Dict[str, Any]] = {} bundle[mime_type] = spec bundle["text/plain"] = str_repr @@ -157,7 +172,9 @@ def default_renderer_base(spec, mime_type, str_repr, **options): return bundle, metadata -def json_renderer_base(spec, str_repr, **options): +def json_renderer_base( + spec: dict, str_repr: str, **options +) -> DefaultRendererReturnType: """A renderer that returns a MIME type of application/json. In JupyterLab/nteract this is rendered as a nice JSON tree. 
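# Aside (an illustrative sketch, not part of the patch): DefaultRendererReturnType
# is the (bundle, metadata) pair that default_renderer_base and json_renderer_base
# return. The spec literal and str_repr value below are assumptions made for the
# example, not Altair API output:

from altair.utils.display import default_renderer_base

spec = {"mark": "point"}
bundle, metadata = default_renderer_base(
    spec, mime_type="application/vnd.vegalite.v5+json", str_repr="Chart(...)"
)
assert bundle["application/vnd.vegalite.v5+json"] == spec  # the spec, keyed by MIME type
assert bundle["text/plain"] == "Chart(...)"  # plain-text fallback representation
assert metadata == {}  # metadata is only populated when renderer options are passed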
@@ -170,15 +187,15 @@
 class HTMLRenderer:
     """Object to render charts as HTML, with a unique output div each time"""
 
-    def __init__(self, output_div="altair-viz-{}", **kwargs):
+    def __init__(self, output_div: str = "altair-viz-{}", **kwargs) -> None:
         self._output_div = output_div
         self.kwargs = kwargs
 
     @property
-    def output_div(self):
+    def output_div(self) -> str:
         return self._output_div.format(uuid.uuid4().hex)
 
-    def __call__(self, spec, **metadata):
+    def __call__(self, spec: dict, **metadata) -> Dict[str, str]:
         kwargs = self.kwargs.copy()
         kwargs.update(metadata)
         return spec_to_mimebundle(
diff --git a/altair/vegalite/display.py b/altair/vegalite/display.py
index 91c5f33e0..1f3a13b46 100644
--- a/altair/vegalite/display.py
+++ b/altair/vegalite/display.py
@@ -1,4 +1,9 @@
-from ..utils.display import Displayable, default_renderer_base, json_renderer_base
+from ..utils.display import (
+    Displayable,
+    default_renderer_base,
+    json_renderer_base,
+    DefaultRendererReturnType,
+)
 from ..utils.display import RendererRegistry, HTMLRenderer
 
 
@@ -8,4 +13,5 @@
     "json_renderer_base",
     "RendererRegistry",
     "HTMLRenderer",
+    "DefaultRendererReturnType",
 )
diff --git a/altair/vegalite/v5/display.py b/altair/vegalite/v5/display.py
index b86a4c936..34ca625e9 100644
--- a/altair/vegalite/v5/display.py
+++ b/altair/vegalite/v5/display.py
@@ -1,11 +1,15 @@
 import os
+from typing import Dict
 
 from ...utils.mimebundle import spec_to_mimebundle
-from ..display import Displayable
-from ..display import default_renderer_base
-from ..display import json_renderer_base
-from ..display import RendererRegistry
-from ..display import HTMLRenderer
+from ..display import (
+    Displayable,
+    default_renderer_base,
+    json_renderer_base,
+    RendererRegistry,
+    HTMLRenderer,
+    DefaultRendererReturnType,
+)
 
 from .schema import SCHEMA_VERSION
 
@@ -20,12 +24,12 @@
 
 
 # The MIME type for Vega-Lite 5.x releases.
-VEGALITE_MIME_TYPE = "application/vnd.vegalite.v5+json"  # type: str
+VEGALITE_MIME_TYPE = "application/vnd.vegalite.v5+json"
 
 # The entry point group that can be used by other packages to declare other
 # renderers that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.renderer"  # type: str
+ENTRY_POINT_GROUP = "altair.vegalite.v5.renderer"
 
 # The display message when rendering fails
 DEFAULT_DISPLAY = """\
 
 If you see this message, it means the renderer has not been properly enabled
 
 here = os.path.dirname(os.path.realpath(__file__))
 
 
-def mimetype_renderer(spec, **metadata):
+def mimetype_renderer(spec: dict, **metadata) -> DefaultRendererReturnType:
     return default_renderer_base(spec, VEGALITE_MIME_TYPE, DEFAULT_DISPLAY, **metadata)
 
 
-def json_renderer(spec, **metadata):
+def json_renderer(spec: dict, **metadata) -> DefaultRendererReturnType:
     return json_renderer_base(spec, DEFAULT_DISPLAY, **metadata)
 
 
-def png_renderer(spec, **metadata):
+def png_renderer(spec: dict, **metadata) -> Dict[str, bytes]:
     return spec_to_mimebundle(
         spec,
         format="png",
@@ -61,7 +65,7 @@ def png_renderer(spec, **metadata):
     )
 
 
-def svg_renderer(spec, **metadata):
+def svg_renderer(spec: dict, **metadata) -> Dict[str, str]:
     return spec_to_mimebundle(
         spec,
         format="svg",
@@ -102,7 +106,7 @@ class VegaLite(Displayable):
     schema_path = (__name__, "schema/vega-lite-schema.json")
 
 
-def vegalite(spec, validate=True):
+def vegalite(spec: dict, validate: bool = True) -> None:
     """Render and optionally validate a VegaLite 5 spec.
 
     This will use the currently enabled renderer to render the spec.
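# Aside (a sketch, assuming the renderer names registered elsewhere in this
# module are unchanged by the patch, which only adds annotations): the typed
# renderer functions above are switched at runtime through the registry:

import altair as alt

alt.renderers.enable("json")      # uses json_renderer -> application/json bundle
alt.renderers.enable("svg")       # uses svg_renderer -> Dict[str, str]
alt.renderers.enable("mimetype")  # uses mimetype_renderer -> DefaultRendererReturnType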
diff --git a/altair/vegalite/v5/theme.py b/altair/vegalite/v5/theme.py
index 5bd6dcecf..6b14fe6d9 100644
--- a/altair/vegalite/v5/theme.py
+++ b/altair/vegalite/v5/theme.py
@@ -34,7 +34,7 @@ def __repr__(self) -> str:
 
 
 # The entry point group that can be used by other packages to declare other
-# renderers that will be auto-detected. Explicit registration is also
+# themes that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
 ENTRY_POINT_GROUP = "altair.vegalite.v5.theme"
 themes = ThemeRegistry(entry_point_group=ENTRY_POINT_GROUP)

From e748ce0b1e23a5465f773795b28195eb51e214cf Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sat, 18 Mar 2023 20:04:20 +0000
Subject: [PATCH 03/17] Type hint data modules and some related files

---
 altair/utils/core.py       |  44 +++++++--------
 altair/utils/data.py       | 108 +++++++++++++++++++++++++++----------
 altair/vegalite/data.py    |   5 +-
 altair/vegalite/v5/data.py |   4 +-
 requirements_dev.txt       |   3 +-
 5 files changed, 109 insertions(+), 55 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index 94033cd31..b8d9c9fcf 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -1,7 +1,7 @@
 """
 Utility routines
 """
-from collections.abc import Mapping
+from collections.abc import Mapping, MutableMapping
 from copy import deepcopy
 import json
 import itertools
@@ -250,7 +250,7 @@ def merge_props_geom(feat):
     return props_geom
 
 
-def sanitize_geo_interface(geo):
+def sanitize_geo_interface(geo: MutableMapping) -> dict:
     """Sanitize a geo_interface to prepare it for serialization.
 
     * Make a copy
@@ -267,23 +267,23 @@
             geo[key] = geo[key].tolist()
 
     # convert (nested) tuples to lists
-    geo = json.loads(json.dumps(geo))
+    geo_dct: dict = json.loads(json.dumps(geo))
 
     # sanitize features
-    if geo["type"] == "FeatureCollection":
-        geo = geo["features"]
-        if len(geo) > 0:
-            for idx, feat in enumerate(geo):
-                geo[idx] = merge_props_geom(feat)
-    elif geo["type"] == "Feature":
-        geo = merge_props_geom(geo)
+    if geo_dct["type"] == "FeatureCollection":
+        geo_dct = geo_dct["features"]
+        if len(geo_dct) > 0:
+            for idx, feat in enumerate(geo_dct):
+                geo_dct[idx] = merge_props_geom(feat)
+    elif geo_dct["type"] == "Feature":
+        geo_dct = merge_props_geom(geo_dct)
     else:
-        geo = {"type": "Feature", "geometry": geo}
+        geo_dct = {"type": "Feature", "geometry": geo_dct}
 
-    return geo
+    return geo_dct
 
 
-def sanitize_dataframe(df):  # noqa: C901
+def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:  # noqa: C901
     """Sanitize a DataFrame to prepare it for serialization.
* Make a copy
@@ -328,20 +328,20 @@ def to_list_if_array(val):
             # Work around bug in to_json for categorical types in older versions of pandas
             # https://github.com/pydata/pandas/issues/10778
             # https://github.com/altair-viz/altair/pull/2170
-            col = df[col_name].astype(object)
+            col = df[col_name].astype(object)  # type: ignore[call-overload]
             df[col_name] = col.where(col.notnull(), None)
         elif str(dtype) == "string":
             # dedicated string datatype (since 1.0)
             # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
-            col = df[col_name].astype(object)
+            col = df[col_name].astype(object)  # type: ignore[call-overload]
             df[col_name] = col.where(col.notnull(), None)
         elif str(dtype) == "bool":
             # convert numpy bools to objects; np.bool is not JSON serializable
-            df[col_name] = df[col_name].astype(object)
+            df[col_name] = df[col_name].astype(object)  # type: ignore[call-overload]
         elif str(dtype) == "boolean":
             # dedicated boolean datatype (since 1.0)
             # https://pandas.io/docs/user_guide/boolean.html
-            col = df[col_name].astype(object)
+            col = df[col_name].astype(object)  # type: ignore[call-overload]
             df[col_name] = col.where(col.notnull(), None)
         elif str(dtype).startswith("datetime"):
             # Convert datetimes to strings. This needs to be a full ISO string
@@ -351,7 +351,7 @@ def to_list_if_array(val):
             # Vega-Lite are displayed in local time by default.
             # (see https://github.com/altair-viz/altair/issues/1027)
             df[col_name] = (
-                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
+                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")  # type: ignore[call-overload]
             )
         elif str(dtype).startswith("timedelta"):
             raise ValueError(
@@ -377,21 +377,21 @@ def to_list_if_array(val):
             "Float64",
         }:  # nullable integer datatypes (since 0.24.0) and nullable float datatypes (since 1.2.0)
             # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
-            col = df[col_name].astype(object)
+            col = df[col_name].astype(object)  # type: ignore[call-overload]
             df[col_name] = col.where(col.notnull(), None)
         elif np.issubdtype(dtype, np.integer):
             # convert integers to objects; np.int is not JSON serializable
-            df[col_name] = df[col_name].astype(object)
+            df[col_name] = df[col_name].astype(object)  # type: ignore[call-overload]
         elif np.issubdtype(dtype, np.floating):
             # For floats, convert to Python float: np.float is not JSON serializable
             # Also convert NaN/inf values to null, as they are not JSON serializable
-            col = df[col_name]
+            col = df[col_name]  # type: ignore[call-overload]
             bad_values = col.isnull() | np.isinf(col)
             df[col_name] = col.astype(object).where(~bad_values, None)
         elif dtype == object:
             # Convert numpy arrays saved as objects to lists
             # Arrays are not JSON serializable
-            col = df[col_name].apply(to_list_if_array, convert_dtype=False)
+            col = df[col_name].apply(to_list_if_array, convert_dtype=False)  # type: ignore[call-overload]
             df[col_name] = col.where(col.notnull(), None)
     return df
diff --git a/altair/utils/data.py b/altair/utils/data.py
index 7990b63e3..f8d7b9e6f 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -2,11 +2,13 @@
 import os
 import random
 import hashlib
+import sys
 import warnings
+from typing import Union, MutableMapping, Optional, Dict, Sequence, TYPE_CHECKING, List
 
 import pandas as pd
 from toolz import curried
-from typing import Callable
+from typing import Callable, TypeVar
 
 from .core import sanitize_dataframe
 from .core import sanitize_geo_interface
 from .deprecation import AltairDeprecationWarning
 from .plugin_registry import 
PluginRegistry +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + + +if TYPE_CHECKING: + import pyarrow.lib + + +class SupportsGeoInterface(Protocol): + __geo_interface__: MutableMapping + + +class SupportsDataframe(Protocol): + def __dataframe__(self, *args, **kwargs): + ... + + +DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, SupportsDataframe] +TDataType = TypeVar("TDataType", bound=DataType) + # ============================================================================== # Data transformer registry # ============================================================================== @@ -24,11 +48,11 @@ class DataTransformerRegistry(PluginRegistry[DataTransformerType]): _global_settings = {"consolidate_datasets": True} @property - def consolidate_datasets(self): + def consolidate_datasets(self) -> bool: return self._global_settings["consolidate_datasets"] @consolidate_datasets.setter - def consolidate_datasets(self, value): + def consolidate_datasets(self, value: bool) -> None: self._global_settings["consolidate_datasets"] = value @@ -58,7 +82,7 @@ class MaxRowsError(Exception): @curried.curry -def limit_rows(data, max_rows=5000): +def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType: """Raise MaxRowsError if the data model has more than max_rows. If max_rows is None, then do not perform any check. @@ -75,7 +99,9 @@ def limit_rows(data, max_rows=5000): if "values" in data: values = data["values"] else: - return data + # mypy gets confused as it doesn't see Dict[Any, Any] + # as equivalent to TDataType + return data # type: ignore[return-value] elif hasattr(data, "__dataframe__"): values = data if max_rows is not None and len(values) > max_rows: @@ -91,7 +117,9 @@ def limit_rows(data, max_rows=5000): @curried.curry -def sample(data, n=None, frac=None): +def sample( + data: DataType, n: Optional[int] = None, frac: Optional[float] = None +) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]: """Reduce the size of the data model by sampling without replacement.""" check_data_type(data) if isinstance(data, pd.DataFrame): @@ -99,26 +127,43 @@ def sample(data, n=None, frac=None): elif isinstance(data, dict): if "values" in data: values = data["values"] - n = n if n else int(frac * len(values)) + if not n: + if frac is None: + raise ValueError( + "frac cannot be None if n is None and data is a dictionary" + ) + n = int(frac * len(values)) values = random.sample(values, n) return {"values": values} + else: + # Maybe this should raise an error or return something useful? + return None elif hasattr(data, "__dataframe__"): # experimental interchange dataframe support pi = import_pyarrow_interchange() pa_table = pi.from_dataframe(data) - n = n if n else int(frac * len(pa_table)) + if not n: + if frac is None: + raise ValueError( + "frac cannot be None if n is None with this data input type" + ) + n = int(frac * len(pa_table)) indices = random.sample(range(len(pa_table)), n) return pa_table.take(indices) + else: + # Maybe this should raise an error or return something useful? 
Currently,
+        # if data is of type SupportsGeoInterface it lands here
+        return None
 
 
 @curried.curry
 def to_json(
-    data,
-    prefix="altair-data",
-    extension="json",
-    filename="{prefix}-{hash}.{extension}",
-    urlpath="",
-):
+    data: DataType,
+    prefix: str = "altair-data",
+    extension: str = "json",
+    filename: str = "{prefix}-{hash}.{extension}",
+    urlpath: str = "",
+) -> Dict[str, Union[str, Dict[str, str]]]:
     """
     Write the data model to a .json file and return a url based data model.
     """
@@ -132,12 +177,12 @@ def to_json(
 
 @curried.curry
 def to_csv(
-    data,
-    prefix="altair-data",
-    extension="csv",
-    filename="{prefix}-{hash}.{extension}",
-    urlpath="",
-):
+    data: Union[dict, pd.DataFrame, SupportsDataframe],
+    prefix: str = "altair-data",
+    extension: str = "csv",
+    filename: str = "{prefix}-{hash}.{extension}",
+    urlpath: str = "",
+) -> Dict[str, Union[str, Dict[str, str]]]:
     """Write the data model to a .csv file and return a url based data model."""
     data_csv = _data_to_csv_string(data)
     data_hash = _compute_data_hash(data_csv)
@@ -148,14 +193,16 @@ def to_csv(
 
 
 @curried.curry
-def to_values(data):
+def to_values(data: DataType) -> Optional[Dict[str, Union[dict, List[dict]]]]:
     """Replace a DataFrame by a data model with values."""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
         if isinstance(data, pd.DataFrame):
             data = sanitize_dataframe(data)
-        data = sanitize_geo_interface(data.__geo_interface__)
-        return {"values": data}
+        # Maybe the type could be further clarified here that it is
+        # SupportsGeoInterface and then the ignore statement is not needed?
+        data_sanitized = sanitize_geo_interface(data.__geo_interface__)  # type: ignore[arg-type]
+        return {"values": data_sanitized}
     elif isinstance(data, pd.DataFrame):
         data = sanitize_dataframe(data)
         return {"values": data.to_dict(orient="records")}
@@ -168,9 +215,12 @@ def to_values(data):
         pi = import_pyarrow_interchange()
         pa_table = pi.from_dataframe(data)
         return {"values": pa_table.to_pylist()}
+    else:
+        # Should this raise an error?
+        return None
 
 
-def check_data_type(data):
+def check_data_type(data: DataType) -> None:
     """Raise if the data is not a dict or DataFrame."""
     if not isinstance(data, (dict, pd.DataFrame)) and not any(
         hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
@@ -187,17 +237,19 @@ def check_data_type(data):
 # ==============================================================================
 
 
-def _compute_data_hash(data_str):
+def _compute_data_hash(data_str: str) -> str:
     return hashlib.md5(data_str.encode()).hexdigest()
 
 
-def _data_to_json_string(data):
+def _data_to_json_string(data: DataType) -> str:
     """Return a JSON string representation of the input data"""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
         if isinstance(data, pd.DataFrame):
             data = sanitize_dataframe(data)
-        data = sanitize_geo_interface(data.__geo_interface__)
+        # Maybe the type could be further clarified here that it is
+        # SupportsGeoInterface and then the ignore statement is not needed? 
+        data = sanitize_geo_interface(data.__geo_interface__)  # type: ignore[arg-type]
     return json.dumps(data)
     elif isinstance(data, pd.DataFrame):
         data = sanitize_dataframe(data)
@@ -217,7 +269,7 @@
     )
 
 
-def _data_to_csv_string(data):
+def _data_to_csv_string(data: Union[dict, pd.DataFrame, SupportsDataframe]) -> str:
     """Return a CSV string representation of the input data"""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):
diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py
index 30289160b..d44022969 100644
--- a/altair/vegalite/data.py
+++ b/altair/vegalite/data.py
@@ -12,15 +12,16 @@
     check_data_type,
 )
 from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry
+from ..utils.plugin_registry import PluginEnabler
 
 
 @curried.curry
-def default_data_transformer(data, max_rows=5000):
+def default_data_transformer(data, max_rows: int = 5000):
     return curried.pipe(data, limit_rows(max_rows=max_rows), to_values)
 
 
 class DataTransformerRegistry(_DataTransformerRegistry):
-    def disable_max_rows(self):
+    def disable_max_rows(self) -> PluginEnabler:
         """Disable the MaxRowsError."""
         options = self.options
         if self.active == "default":
diff --git a/altair/vegalite/v5/data.py b/altair/vegalite/v5/data.py
index 703dffb32..e14ee8136 100644
--- a/altair/vegalite/v5/data.py
+++ b/altair/vegalite/v5/data.py
@@ -17,12 +17,12 @@
 # ==============================================================================
 
 
-ENTRY_POINT_GROUP = "altair.vegalite.v5.data_transformer"  # type: str
+ENTRY_POINT_GROUP = "altair.vegalite.v5.data_transformer"
 
 
 data_transformers = DataTransformerRegistry(
     entry_point_group=ENTRY_POINT_GROUP
-)  # type: DataTransformerRegistry
+)
 data_transformers.register("default", default_data_transformer)
 data_transformers.register("json", to_json)
 data_transformers.register("csv", to_csv)
diff --git a/requirements_dev.txt b/requirements_dev.txt
index f8ccff938..9ba155203 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -10,4 +10,5 @@ vl-convert-python
 mypy
 pandas-stubs
 types-jsonschema
-types-setuptools
\ No newline at end of file
+types-setuptools
+pyarrow
\ No newline at end of file

From 7feb6ab88f6a57d13b27be7a0adcf63f41c6aaaf Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sat, 18 Mar 2023 20:39:38 +0000
Subject: [PATCH 04/17] Type hint core.py

---
 altair/utils/core.py | 63 ++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 26 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index b8d9c9fcf..183c4af42 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -9,7 +9,8 @@
 import sys
 import traceback
 import warnings
-from typing import Callable, TypeVar, Any
+from typing import Callable, TypeVar, Any, Union, Dict, Optional, Tuple, Sequence, Type
+from types import ModuleType
 
 import jsonschema
 import pandas as pd
@@ -190,7 +191,9 @@
 ]
 
 
-def infer_vegalite_type(data):
+def infer_vegalite_type(
+    data: Union[np.ndarray, pd.Series]
+) -> Union[str, Tuple[str, list]]:
     """
     From an array-like input, infer the correct vega typecode
     ('ordinal', 'nominal', 'quantitative', or 'temporal')
@@ -210,8 +213,10 @@ def infer_vegalite_type(data):
         "complex",
     ]:
         return "quantitative"
-    elif typ == "categorical" and data.cat.ordered:
-        return ("ordinal", data.cat.categories.tolist())
+    # Can ignore error that np.ndarray has no attribute cat as in this case
+    # it should always be a pd.Series anyway
+    elif typ == "categorical" and data.cat.ordered:  # type: 
ignore[union-attr] + return ("ordinal", data.cat.categories.tolist()) # type: ignore[union-attr] elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]: return "nominal" elif typ in [ @@ -232,7 +237,7 @@ def infer_vegalite_type(data): return "nominal" -def merge_props_geom(feat): +def merge_props_geom(feat: dict) -> dict: """ Merge properties with geometry * Overwrites 'type' and 'geometry' entries if existing @@ -397,13 +402,13 @@ def to_list_if_array(val): def parse_shorthand( - shorthand, - data=None, - parse_aggregates=True, - parse_window_ops=False, - parse_timeunits=True, - parse_types=True, -): + shorthand: Union[Dict[str, Any], str], + data: Optional[pd.DataFrame] = None, + parse_aggregates: bool = True, + parse_window_ops: bool = False, + parse_timeunits: bool = True, + parse_types: bool = True, +) -> Dict[str, Any]: """General tool to parse shorthand values These are of the form: @@ -518,7 +523,9 @@ def parse_shorthand( attrs = shorthand else: attrs = next( - exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand) + exp.match(shorthand).groupdict() # type: ignore[union-attr] + for exp in regexps + if exp.match(shorthand) is not None ) # Handle short form of the type expression @@ -593,21 +600,23 @@ def decorate(f: Callable[..., _V]) -> Callable[_P, _V]: return decorate -def update_nested(original, update, copy=False): +def update_nested( + original: MutableMapping, update: Mapping, copy: bool = False +) -> MutableMapping: """Update nested dictionaries Parameters ---------- - original : dict + original : MutableMapping the original (nested) dictionary, which will be updated in-place - update : dict + update : Mapping the nested dictionary of updates copy : bool, default False if True, then copy the original dictionary rather than modifying it Returns ------- - original : dict + original : MutableMapping a reference to the (modified) original dict Examples @@ -624,7 +633,7 @@ def update_nested(original, update, copy=False): for key, val in update.items(): if isinstance(val, Mapping): orig_val = original.get(key, {}) - if isinstance(orig_val, Mapping): + if isinstance(orig_val, MutableMapping): original[key] = update_nested(orig_val, val) else: original[key] = val @@ -633,7 +642,7 @@ def update_nested(original, update, copy=False): return original -def display_traceback(in_ipython=True): +def display_traceback(in_ipython: bool = True): exc_info = sys.exc_info() if in_ipython: @@ -649,16 +658,16 @@ def display_traceback(in_ipython=True): traceback.print_exception(*exc_info) -def infer_encoding_types(args, kwargs, channels): +def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType): """Infer typed keyword arguments for args and kwargs Parameters ---------- - args : tuple - List of function args - kwargs : dict + args : Sequence + Sequence of function args + kwargs : MutableMapping Dict of function kwargs - channels : module + channels : ModuleType The module containing all altair encoding channel classes. 
Returns @@ -673,8 +682,10 @@ def infer_encoding_types(args, kwargs, channels): channel_objs = ( c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase) ) - channel_to_name = {c: c._encoding_name for c in channel_objs} - name_to_channel = {} + channel_to_name: Dict[Type[SchemaBase], str] = { + c: c._encoding_name for c in channel_objs + } + name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {} for chan, name in channel_to_name.items(): chans = name_to_channel.setdefault(name, {}) if chan.__name__.endswith("Datum"): From dd5694ad192705fe789f9dd0eda00113a8cc8ac9 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 18 Mar 2023 20:44:55 +0000 Subject: [PATCH 05/17] Type hint html.py --- altair/utils/html.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/altair/utils/html.py b/altair/utils/html.py index 6be18139c..b3dff16b7 100644 --- a/altair/utils/html.py +++ b/altair/utils/html.py @@ -1,4 +1,6 @@ import json +from typing import Optional, Dict + import jinja2 @@ -173,7 +175,7 @@ ) -TEMPLATES = { +TEMPLATES: Dict[str, jinja2.Template] = { "standard": HTML_TEMPLATE, "universal": HTML_TEMPLATE_UNIVERSAL, "inline": INLINE_HTML_TEMPLATE, @@ -181,19 +183,19 @@ def spec_to_html( - spec, - mode, - vega_version, - vegaembed_version, - vegalite_version=None, - base_url="https://cdn.jsdelivr.net/npm", - output_div="vis", - embed_options=None, - json_kwds=None, - fullhtml=True, - requirejs=False, - template="standard", -): + spec: dict, + mode: str, + vega_version: str, + vegaembed_version: str, + vegalite_version: Optional[str] = None, + base_url: str = "https://cdn.jsdelivr.net/npm", + output_div: str = "vis", + embed_options: Optional[dict] = None, + json_kwds: Optional[dict] = None, + fullhtml: bool = True, + requirejs: bool = False, + template: str = "standard", +) -> str: """Embed a Vega/Vega-Lite spec into an HTML page Parameters @@ -267,11 +269,11 @@ def spec_to_html( "vega-embed", vegaembed_version ) - template = TEMPLATES.get(template, template) - if not hasattr(template, "render"): - raise ValueError("Invalid template: {0}".format(template)) + jinja_template = TEMPLATES.get(template, template) + if not hasattr(jinja_template, "render"): + raise ValueError("Invalid template: {0}".format(jinja_template)) - return template.render( + return jinja_template.render( spec=json.dumps(spec, **json_kwds), embed_options=json.dumps(embed_options), mode=mode, From e1693e5bef1bb6b02a635930f8efb6732dca58ac Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 18 Mar 2023 20:47:31 +0000 Subject: [PATCH 06/17] Format code --- altair/vegalite/v5/data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/altair/vegalite/v5/data.py b/altair/vegalite/v5/data.py index e14ee8136..1a50d2bd9 100644 --- a/altair/vegalite/v5/data.py +++ b/altair/vegalite/v5/data.py @@ -20,9 +20,7 @@ ENTRY_POINT_GROUP = "altair.vegalite.v5.data_transformer" -data_transformers = DataTransformerRegistry( - entry_point_group=ENTRY_POINT_GROUP -) +data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP) data_transformers.register("default", default_data_transformer) data_transformers.register("json", to_json) data_transformers.register("csv", to_csv) From 3f2c22096e1063871c11eeb044da9f40247c7c5a Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 10 Jun 2023 12:42:18 +0000 Subject: [PATCH 07/17] Type hint default_data_transformer --- altair/utils/data.py | 3 ++- altair/vegalite/data.py | 7 ++++++- 2 files 
changed, 8 insertions(+), 2 deletions(-) diff --git a/altair/utils/data.py b/altair/utils/data.py index ab95ea0ef..dba986dc3 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -37,6 +37,7 @@ def __dataframe__(self, *args, **kwargs): DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, SupportsDataframe] TDataType = TypeVar("TDataType", bound=DataType) +ToValuesReturnType = Optional[Dict[str, Union[dict, List[dict]]]] # ============================================================================== # Data transformer registry @@ -193,7 +194,7 @@ def to_csv( @curried.curry -def to_values(data: DataType) -> Optional[Dict[str, Union[dict, List[dict]]]]: +def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) if hasattr(data, "__geo_interface__"): diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py index d44022969..12a14a983 100644 --- a/altair/vegalite/data.py +++ b/altair/vegalite/data.py @@ -1,3 +1,5 @@ +from typing import Optional, Dict, Union, List + from toolz import curried from ..utils.core import sanitize_dataframe from ..utils.data import ( @@ -12,11 +14,14 @@ check_data_type, ) from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry +from ..utils.data import DataType, ToValuesReturnType from ..utils.plugin_registry import PluginEnabler @curried.curry -def default_data_transformer(data, max_rows: int = 5000): +def default_data_transformer( + data: DataType, max_rows: int = 5000 +) -> ToValuesReturnType: return curried.pipe(data, limit_rows(max_rows=max_rows), to_values) From cb2a8a8fdfc6a284fb48e64ebbac91094c0ac518 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 10 Jun 2023 13:05:48 +0000 Subject: [PATCH 08/17] Improve type hints on data transformers --- altair/utils/data.py | 71 ++++++++++++++++++++++++----------------- altair/vegalite/data.py | 2 -- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/altair/utils/data.py b/altair/utils/data.py index dba986dc3..bcdea935c 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -5,10 +5,11 @@ import sys import warnings from typing import Union, MutableMapping, Optional, Dict, Sequence, TYPE_CHECKING, List +from types import ModuleType import pandas as pd from toolz import curried -from typing import Callable, TypeVar +from typing import TypeVar from .core import sanitize_dataframe, sanitize_arrow_table from .core import sanitize_geo_interface @@ -17,9 +18,9 @@ if sys.version_info >= (3, 8): - from typing import Protocol + from typing import Protocol, TypedDict, Literal else: - from typing_extensions import Protocol + from typing_extensions import Protocol, TypedDict, Literal if TYPE_CHECKING: @@ -37,12 +38,25 @@ def __dataframe__(self, *args, **kwargs): DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, SupportsDataframe] TDataType = TypeVar("TDataType", bound=DataType) -ToValuesReturnType = Optional[Dict[str, Union[dict, List[dict]]]] + +VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]] +ToValuesReturnType = Dict[str, Union[dict, List[dict]]] + # ============================================================================== # Data transformer registry +# +# A data transformer is a callable that takes a supported data type and returns +# a transformed dictionary version of it which is compatible with the VegaLite schema. +# The dict objects will be the Data portion of the VegaLite schema. 
+# +# Renderers only deal with the dict form of a +# VegaLite spec, after the Data model has been put into a schema compliant +# form. # ============================================================================== -DataTransformerType = Callable +class DataTransformerType(Protocol): + def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: + pass class DataTransformerRegistry(PluginRegistry[DataTransformerType]): @@ -58,24 +72,6 @@ def consolidate_datasets(self, value: bool) -> None: # ============================================================================== -# Data model transformers -# -# A data model transformer is a pure function that takes a dict or DataFrame -# and returns a transformed version of a dict or DataFrame. The dict objects -# will be the Data portion of the VegaLite schema. The idea is that user can -# pipe a sequence of these data transformers together to prepare the data before -# it hits the renderer. -# -# In this version of Altair, renderers only deal with the dict form of a -# VegaLite spec, after the Data model has been put into a schema compliant -# form. -# -# A data model transformer has the following type signature: -# DataModelType = Union[dict, pd.DataFrame] -# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType] -# ============================================================================== - - class MaxRowsError(Exception): """Raised when a data model has too many rows.""" @@ -157,6 +153,24 @@ def sample( return None +class _JsonFormatDict(TypedDict): + type: Literal["json"] + + +class _CsvFormatDict(TypedDict): + type: Literal["csv"] + + +class _DataJsonUrlDict(TypedDict): + url: str + format: _JsonFormatDict + + +class _DataCsvUrlDict(TypedDict): + url: str + format: _CsvFormatDict + + @curried.curry def to_json( data: DataType, @@ -164,7 +178,7 @@ def to_json( extension: str = "json", filename: str = "{prefix}-{hash}.{extension}", urlpath: str = "", -) -> Dict[str, Union[str, Dict[str, str]]]: +) -> _DataJsonUrlDict: """ Write the data model to a .json file and return a url based data model. """ @@ -183,7 +197,7 @@ def to_csv( extension: str = "csv", filename: str = "{prefix}-{hash}.{extension}", urlpath: str = "", -) -> Dict[str, Union[str, Dict[str, str]]]: +) -> _DataCsvUrlDict: """Write the data model to a .csv file and return a url based data model.""" data_csv = _data_to_csv_string(data) data_hash = _compute_data_hash(data_csv) @@ -217,12 +231,11 @@ def to_values(data: DataType) -> ToValuesReturnType: pa_table = sanitize_arrow_table(pi.from_dataframe(data)) return {"values": pa_table.to_pylist()} else: - # Should this raise an error? 
- return None + # Should never reach this state as tested by check_data_type + raise ValueError("Unrecognized data type: {}".format(type(data))) def check_data_type(data: DataType) -> None: - """Raise if the data is not a dict or DataFrame.""" if not isinstance(data, (dict, pd.DataFrame)) and not any( hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"] ): @@ -328,7 +341,7 @@ def curry(*args, **kwargs): return curried.curry(*args, **kwargs) -def import_pyarrow_interchange(): +def import_pyarrow_interchange() -> ModuleType: import pkg_resources try: diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py index 12a14a983..fdc9fe3c2 100644 --- a/altair/vegalite/data.py +++ b/altair/vegalite/data.py @@ -1,5 +1,3 @@ -from typing import Optional, Dict, Union, List - from toolz import curried from ..utils.core import sanitize_dataframe from ..utils.data import ( From 570a358f9fd061e69d53090ef666584fbd26fcbb Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 10 Jun 2023 13:16:28 +0000 Subject: [PATCH 09/17] Ignore mypy missing import error for nbformat and ipykernel. Unclear why mypy suddenly complains about those... --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d17777c24..4f781909c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,7 +228,9 @@ module = [ "pyarrow.*", "yaml.*", "vl_convert.*", - "pandas.lib.*" + "pandas.lib.*", + "nbformat.*", + "ipykernel.*" ] ignore_missing_imports = true From a166ee06d05c3579e4cc5429122e8d1c2aeed2f3 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 10 Jun 2023 13:20:22 +0000 Subject: [PATCH 10/17] Use Final for constants --- altair/vegalite/v5/compiler.py | 10 +++++++++- altair/vegalite/v5/data.py | 9 ++++++++- altair/vegalite/v5/display.py | 18 ++++++++++++------ altair/vegalite/v5/theme.py | 8 +++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/altair/vegalite/v5/compiler.py b/altair/vegalite/v5/compiler.py index 60afa99f0..a4e02f79f 100644 --- a/altair/vegalite/v5/compiler.py +++ b/altair/vegalite/v5/compiler.py @@ -1,6 +1,14 @@ +import sys + from ...utils.compiler import VegaLiteCompilerRegistry -ENTRY_POINT_GROUP: str = "altair.vegalite.v5.vegalite_compiler" +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + + +ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.vegalite_compiler" vegalite_compilers = VegaLiteCompilerRegistry(entry_point_group=ENTRY_POINT_GROUP) diff --git a/altair/vegalite/v5/data.py b/altair/vegalite/v5/data.py index 1a50d2bd9..af7af4d04 100644 --- a/altair/vegalite/v5/data.py +++ b/altair/vegalite/v5/data.py @@ -1,3 +1,5 @@ +import sys + from ..data import ( MaxRowsError, curry, @@ -11,13 +13,18 @@ DataTransformerRegistry, ) +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + # ============================================================================== # VegaLite 5 data transformers # ============================================================================== -ENTRY_POINT_GROUP = "altair.vegalite.v5.data_transformer" +ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.data_transformer" data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP) diff --git a/altair/vegalite/v5/display.py b/altair/vegalite/v5/display.py index 087b896d8..d22b938a3 100644 --- a/altair/vegalite/v5/display.py +++ b/altair/vegalite/v5/display.py @@ -1,4 +1,5 @@ import os +import sys from typing 
import Dict
 
 from ...utils.mimebundle import spec_to_mimebundle
@@ -13,9 +14,14 @@
 from .schema import SCHEMA_VERSION
 
-VEGALITE_VERSION = SCHEMA_VERSION.lstrip("v")
-VEGA_VERSION = "5"
-VEGAEMBED_VERSION = "6"
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
+VEGALITE_VERSION: Final = SCHEMA_VERSION.lstrip("v")
+VEGA_VERSION: Final = "5"
+VEGAEMBED_VERSION: Final = "6"
 
 
 # ==============================================================================
@@ -24,15 +30,15 @@
 
 
 # The MIME type for Vega-Lite 5.x releases.
-VEGALITE_MIME_TYPE = "application/vnd.vegalite.v5+json"
+VEGALITE_MIME_TYPE: Final = "application/vnd.vegalite.v5+json"
 
 # The entry point group that can be used by other packages to declare other
 # renderers that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.renderer"
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.renderer"
 
 # The display message when rendering fails
-DEFAULT_DISPLAY = """\
+DEFAULT_DISPLAY: Final = """\
 
 If you see this message, it means the renderer has not been properly enabled
diff --git a/altair/vegalite/v5/theme.py b/altair/vegalite/v5/theme.py
index 6b14fe6d9..cdaca2dbf 100644
--- a/altair/vegalite/v5/theme.py
+++ b/altair/vegalite/v5/theme.py
@@ -1,4 +1,5 @@
 """Tools for enabling and registering chart themes"""
+import sys
 from typing import Dict, Union
 
 from ...utils.theme import ThemeRegistry
@@ -16,6 +17,11 @@
     "powerbi",
 ]
 
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
 
 class VegaTheme:
     """Implementation of a builtin vega theme."""
@@ -36,7 +42,7 @@ def __repr__(self) -> str:
 # The entry point group that can be used by other packages to declare other
 # themes that will be auto-detected. Explicit registration is also
 # allowed by the PluginRegistry API.
-ENTRY_POINT_GROUP = "altair.vegalite.v5.theme"
+ENTRY_POINT_GROUP: Final = "altair.vegalite.v5.theme"
 themes = ThemeRegistry(entry_point_group=ENTRY_POINT_GROUP)
 
 themes.register(

From d34d07d223a8d69b0b626348221d1236b432914d Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sat, 10 Jun 2023 13:49:38 +0000
Subject: [PATCH 11/17] Mark some new type aliases/protocols as private where
 we might not yet want users to rely on them.
 This gives more flexibility to change them in the future.

---
 altair/utils/core.py    | 10 +++++++++-
 altair/utils/data.py    | 30 +++++++++++++++---------------
 altair/vegalite/data.py |  6 +++---
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index 9d55dc9ba..10fd47005 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -23,6 +23,11 @@
 else:
     from typing_extensions import ParamSpec
 
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 try:
     from pandas.api.types import infer_dtype as _infer_dtype
 except ImportError:
@@ -191,9 +196,12 @@
 ]
 
 
+InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
+
+
 def infer_vegalite_type(
     data: Union[np.ndarray, pd.Series]
-) -> Union[str, Tuple[str, list]]:
+) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
     """
     From an array-like input, infer the correct vega typecode
     ('ordinal', 'nominal', 'quantitative', or 'temporal')
diff --git a/altair/utils/data.py b/altair/utils/data.py
index bcdea935c..e48f57242 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -27,20 +27,20 @@
     import pyarrow.lib
 
 
-class SupportsGeoInterface(Protocol):
+class _SupportsGeoInterface(Protocol):
     __geo_interface__: MutableMapping
 
 
-class SupportsDataframe(Protocol):
+class _SupportsDataframe(Protocol):
     def __dataframe__(self, *args, **kwargs):
         ...
 
 
-DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, SupportsDataframe]
-TDataType = TypeVar("TDataType", bound=DataType)
+_DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _SupportsDataframe]
+_TDataType = TypeVar("_TDataType", bound=_DataType)
 
-VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
-ToValuesReturnType = Dict[str, Union[dict, List[dict]]]
+_VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
+_ToValuesReturnType = Dict[str, Union[dict, List[dict]]]
 
 
 # ==============================================================================
@@ -55,7 +55,7 @@
 # form.
 # ==============================================================================
 class DataTransformerType(Protocol):
-    def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict:
+    def __call__(self, data: _DataType, **kwargs) -> _VegaLiteDataDict:
         pass
 
 
@@ -79,7 +79,7 @@ class MaxRowsError(Exception):
 
 
 @curried.curry
-def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType:
+def limit_rows(data: _TDataType, max_rows: Optional[int] = 5000) -> _TDataType:
     """Raise MaxRowsError if the data model has more than max_rows.
 
     If max_rows is None, then do not perform any check.
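# Aside (an illustrative sketch, not part of the patch): anything matching the
# DataTransformerType protocol referenced above can be registered. The "head"
# transformer name and its truncation behaviour are assumptions made for this
# example, not Altair API:

import pandas as pd
from altair.utils.data import to_values
from altair.vegalite.v5.data import data_transformers

def head_transformer(data, max_rows: int = 100):
    # Keep only the first max_rows records, then let to_values build the
    # VegaLite-compatible {"values": [...]} dictionary.
    if isinstance(data, pd.DataFrame):
        data = data.head(max_rows)
    return to_values(data)

data_transformers.register("head", head_transformer)
data_transformers.enable("head")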
@@ -115,7 +115,7 @@ def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType: @curried.curry def sample( - data: DataType, n: Optional[int] = None, frac: Optional[float] = None + data: _DataType, n: Optional[int] = None, frac: Optional[float] = None ) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]: """Reduce the size of the data model by sampling without replacement.""" check_data_type(data) @@ -173,7 +173,7 @@ class _DataCsvUrlDict(TypedDict): @curried.curry def to_json( - data: DataType, + data: _DataType, prefix: str = "altair-data", extension: str = "json", filename: str = "{prefix}-{hash}.{extension}", @@ -192,7 +192,7 @@ def to_json( @curried.curry def to_csv( - data: Union[dict, pd.DataFrame, SupportsDataframe], + data: Union[dict, pd.DataFrame, _SupportsDataframe], prefix: str = "altair-data", extension: str = "csv", filename: str = "{prefix}-{hash}.{extension}", @@ -208,7 +208,7 @@ def to_csv( @curried.curry -def to_values(data: DataType) -> ToValuesReturnType: +def to_values(data: _DataType) -> _ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) if hasattr(data, "__geo_interface__"): @@ -235,7 +235,7 @@ def to_values(data: DataType) -> ToValuesReturnType: raise ValueError("Unrecognized data type: {}".format(type(data))) -def check_data_type(data: DataType) -> None: +def check_data_type(data: _DataType) -> None: if not isinstance(data, (dict, pd.DataFrame)) and not any( hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"] ): @@ -253,7 +253,7 @@ def _compute_data_hash(data_str: str) -> str: return hashlib.md5(data_str.encode()).hexdigest() -def _data_to_json_string(data: DataType) -> str: +def _data_to_json_string(data: _DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) if hasattr(data, "__geo_interface__"): @@ -281,7 +281,7 @@ def _data_to_json_string(data: DataType) -> str: ) -def _data_to_csv_string(data: Union[dict, pd.DataFrame, SupportsDataframe]) -> str: +def _data_to_csv_string(data: Union[dict, pd.DataFrame, _SupportsDataframe]) -> str: """return a CSV string representation of the input data""" check_data_type(data) if hasattr(data, "__geo_interface__"): diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py index fdc9fe3c2..3aca8ea5b 100644 --- a/altair/vegalite/data.py +++ b/altair/vegalite/data.py @@ -12,14 +12,14 @@ check_data_type, ) from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry -from ..utils.data import DataType, ToValuesReturnType +from ..utils.data import _DataType, _ToValuesReturnType from ..utils.plugin_registry import PluginEnabler @curried.curry def default_data_transformer( - data: DataType, max_rows: int = 5000 -) -> ToValuesReturnType: + data: _DataType, max_rows: int = 5000 +) -> _ToValuesReturnType: return curried.pipe(data, limit_rows(max_rows=max_rows), to_values) From 4ef8f93256c4d733505b12f1ba82955cc7dea091 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 16 Jun 2023 17:12:38 +0000 Subject: [PATCH 12/17] Remove unused 'type ignore' statements. 
Add mypy error if it detects such a redundant statement --- altair/utils/core.py | 18 +++++++++--------- altair/utils/schemapi.py | 2 +- altair/vegalite/v5/api.py | 4 ++-- pyproject.toml | 3 +++ tools/schemapi/schemapi.py | 2 +- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 95b7d0603..4f5f61487 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -347,20 +347,20 @@ def to_list_if_array(val): # Work around bug in to_json for categorical types in older versions of pandas # https://github.com/pydata/pandas/issues/10778 # https://github.com/altair-viz/altair/pull/2170 - col = df[col_name].astype(object) # type: ignore[call-overload] + col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif str(dtype) == "string": # dedicated string datatype (since 1.0) # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type - col = df[col_name].astype(object) # type: ignore[call-overload] + col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif str(dtype) == "bool": # convert numpy bools to objects; np.bool is not JSON serializable - df[col_name] = df[col_name].astype(object) # type: ignore[call-overload] + df[col_name] = df[col_name].astype(object) elif str(dtype) == "boolean": # dedicated boolean datatype (since 1.0) # https://pandas.io/docs/user_guide/boolean.html - col = df[col_name].astype(object) # type: ignore[call-overload] + col = df[col_name].astype(object) df[col_name] = col.where(col.notnull(), None) elif str(dtype).startswith("datetime"): # Convert datetimes to strings. This needs to be a full ISO string @@ -370,7 +370,7 @@ def to_list_if_array(val): # Vega-Lite are displayed in local time by default. 
# (see https://github.com/altair-viz/altair/issues/1027)
             df[col_name] = (
-                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")  # type: ignore[call-overload]
+                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
             )
         elif str(dtype).startswith("timedelta"):
             raise ValueError(
@@ -396,21 +396,21 @@ def to_list_if_array(val):
             "Float64",
         }:  # nullable integer datatypes (since 0.24.0) and nullable float datatypes (since 1.2.0)
             # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
-            col = df[col_name].astype(object)  # type: ignore[call-overload]
+            col = df[col_name].astype(object)
             df[col_name] = col.where(col.notnull(), None)
         elif np.issubdtype(dtype, np.integer):
             # convert integers to objects; np.int is not JSON serializable
-            df[col_name] = df[col_name].astype(object)  # type: ignore[call-overload]
+            df[col_name] = df[col_name].astype(object)
         elif np.issubdtype(dtype, np.floating):
             # For floats, convert to Python float: np.float is not JSON serializable
             # Also convert NaN/inf values to null, as they are not JSON serializable
-            col = df[col_name]  # type: ignore[call-overload]
+            col = df[col_name]
             bad_values = col.isnull() | np.isinf(col)
             df[col_name] = col.astype(object).where(~bad_values, None)
         elif dtype == object:
             # Convert numpy arrays saved as objects to lists
             # Arrays are not JSON serializable
-            col = df[col_name].apply(to_list_if_array, convert_dtype=False)  # type: ignore[call-overload]
+            col = df[col_name].apply(to_list_if_array, convert_dtype=False)
             df[col_name] = col.where(col.notnull(), None)
     return df
diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py
index 84351b9b0..75568591d 100644
--- a/altair/utils/schemapi.py
+++ b/altair/utils/schemapi.py
@@ -461,7 +461,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
     """Format param names into a table so that they are easier to read"""
     param_names: Tuple[str, ...]
     name_lengths: Tuple[int, ...]
-    param_names, name_lengths = zip(  # type: ignore[assignment]  # Mypy does think it's Tuple[Any]
+    param_names, name_lengths = zip(
         *[
             (name, len(name))
             for name in param_dict_keys
diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index 9e97e3dff..cb81e565c 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -2369,7 +2369,7 @@ def show(self, embed_opt=None, open_browser=None):
             connected to a browser.
         """
         try:
-            import altair_viewer  # type: ignore
+            import altair_viewer
         except ImportError as err:
             raise ValueError(
                 "'show' method requires the altair_viewer package. "
@@ -2468,7 +2468,7 @@ def facet(
         )
 
         # Remove "ignore" statement once Undefined is no longer typed as Any
-        if data is Undefined:  # type: ignore
+        if data is Undefined:
             # Remove "ignore" statement once Undefined is no longer typed as Any
             if self.data is Undefined:  # type: ignore
                 raise ValueError(
diff --git a/pyproject.toml b/pyproject.toml
index eafdc2621..e191ddbac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -220,6 +220,9 @@ markers = [
     "save_engine: marks some of the tests which are using an external package to save a chart to e.g. a png file. 
This mark is used to run those tests selectively in the build GitHub Action.",
 ]
 
+[tool.mypy]
+warn_unused_ignores = true
+
 [[tool.mypy.overrides]]
 module = [
     "vega_datasets.*",
diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py
index 7809acafc..5a43f574d 100644
--- a/tools/schemapi/schemapi.py
+++ b/tools/schemapi/schemapi.py
@@ -459,7 +459,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
     """Format param names into a table so that they are easier to read"""
     param_names: Tuple[str, ...]
     name_lengths: Tuple[int, ...]
-    param_names, name_lengths = zip(  # type: ignore[assignment]  # Mypy does think it's Tuple[Any]
+    param_names, name_lengths = zip(
         *[
             (name, len(name))
             for name in param_dict_keys

From 12615111a971401137985e37f5d15a37b642aec1 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Fri, 16 Jun 2023 17:16:09 +0000
Subject: [PATCH 13/17] Replace _SupportsDataframe with _DataFrameLike

---
 altair/utils/data.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/altair/utils/data.py b/altair/utils/data.py
index e48f57242..756b0d7d0 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -11,7 +11,7 @@
 from toolz import curried
 from typing import TypeVar
 
-from .core import sanitize_dataframe, sanitize_arrow_table
+from .core import sanitize_dataframe, sanitize_arrow_table, _DataFrameLike
 from .core import sanitize_geo_interface
 from .deprecation import AltairDeprecationWarning
 from .plugin_registry import PluginRegistry
@@ -31,12 +31,8 @@ class _SupportsGeoInterface(Protocol):
     __geo_interface__: MutableMapping
 
 
-class _SupportsDataframe(Protocol):
-    def __dataframe__(self, *args, **kwargs):
-        ...
-
-
-_DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _SupportsDataframe]
+_DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _DataFrameLike]
 _TDataType = TypeVar("_TDataType", bound=_DataType)
 
 _VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]]
@@ -192,7 +188,7 @@ def to_json(
 
 @curried.curry
 def to_csv(
-    data: Union[dict, pd.DataFrame, _SupportsDataframe],
+    data: Union[dict, pd.DataFrame, _DataFrameLike],
     prefix: str = "altair-data",
     extension: str = "csv",
     filename: str = "{prefix}-{hash}.{extension}",
@@ -281,7 +277,7 @@ def _data_to_json_string(data: _DataType) -> str:
     )
 
 
-def _data_to_csv_string(data: Union[dict, pd.DataFrame, _SupportsDataframe]) -> str:
+def _data_to_csv_string(data: Union[dict, pd.DataFrame, _DataFrameLike]) -> str:
     """return a CSV string representation of the input data"""
     check_data_type(data)
     if hasattr(data, "__geo_interface__"):

From fa9c0810242064f75cc27e192221516849eecf5b Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Fri, 16 Jun 2023 17:25:04 +0000
Subject: [PATCH 14/17] Format code

---
 altair/utils/data.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/altair/utils/data.py b/altair/utils/data.py
index 756b0d7d0..8c57cefd6 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -31,7 +31,6 @@ class _SupportsGeoInterface(Protocol):
     __geo_interface__: MutableMapping
 
-
 _DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _DataFrameLike]
 _TDataType = TypeVar("_TDataType", bound=_DataType)
 
From 1b0d53198186cf3fefb491fc799ff3bbcfb92a02 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Fri, 16 Jun 2023 17:27:59 +0000
Subject: [PATCH 15/17] Rename two TypedDicts to make it clearer that they
 reference the return type of functions

---
 altair/utils/data.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/altair/utils/data.py b/altair/utils/data.py
index 8c57cefd6..9355dbef7 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -156,12 +156,12 @@ class _CsvFormatDict(TypedDict):
     type: Literal["csv"]
 
 
-class _DataJsonUrlDict(TypedDict):
+class _ToJsonReturnUrlDict(TypedDict):
     url: str
     format: _JsonFormatDict
 
 
-class _DataCsvUrlDict(TypedDict):
+class _ToCsvReturnUrlDict(TypedDict):
     url: str
     format: _CsvFormatDict
 
@@ -173,7 +173,7 @@ def to_json(
     extension: str = "json",
     filename: str = "{prefix}-{hash}.{extension}",
     urlpath: str = "",
-) -> _DataJsonUrlDict:
+) -> _ToJsonReturnUrlDict:
     """
     Write the data model to a .json file and return a url based data model.
     """
@@ -192,7 +192,7 @@ def to_csv(
     extension: str = "csv",
     filename: str = "{prefix}-{hash}.{extension}",
     urlpath: str = "",
-) -> _DataCsvUrlDict:
+) -> _ToCsvReturnUrlDict:
     """Write the data model to a .csv file and return a url based data model."""
     data_csv = _data_to_csv_string(data)
     data_hash = _compute_data_hash(data_csv)

From 15f89fdbb01dba1cb9d18ff0ea285cd41bcbbddf Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Fri, 16 Jun 2023 20:45:35 +0000
Subject: [PATCH 16/17] Make InferredVegaLiteType private

---
 altair/utils/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index 4f5f61487..441ad42f4 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -201,12 +201,12 @@
 ]
 
 
-InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
+_InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
 
 
 def infer_vegalite_type(
     data: Union[np.ndarray, pd.Series]
-) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
+) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
     """
     From an array-like input, infer the correct vega typecode
     ('ordinal', 'nominal', 'quantitative', or 'temporal')

From e974bc96d7ba07dc09f0a1d866c00c30c8b34c07 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Sun, 18 Jun 2023 20:51:34 +0200
Subject: [PATCH 17/17] Type hint 'data' in infer_vegalite_type as 'object',
 matching the type hint used by pandas for its infer_dtype function

---
 altair/utils/core.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/altair/utils/core.py b/altair/utils/core.py
index 441ad42f4..c429a898f 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -43,7 +43,7 @@ def __dataframe__(self, *args, **kwargs):
         ...
 
 
-def infer_dtype(value):
+def infer_dtype(value: object) -> str:
     """Infer the dtype of the value.
This is a compatibility function for pandas infer_dtype, @@ -54,10 +54,10 @@ def infer_dtype(value): _infer_dtype([1], skipna=False) except TypeError: # pandas < 0.21.0 don't support skipna keyword - infer_dtype._supports_skipna = False + infer_dtype._supports_skipna = False # type: ignore[attr-defined] else: - infer_dtype._supports_skipna = True - if infer_dtype._supports_skipna: + infer_dtype._supports_skipna = True # type: ignore[attr-defined] + if infer_dtype._supports_skipna: # type: ignore[attr-defined] return _infer_dtype(value, skipna=False) else: return _infer_dtype(value) @@ -205,7 +205,7 @@ def infer_dtype(value): def infer_vegalite_type( - data: Union[np.ndarray, pd.Series] + data: object, ) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]: """ From an array-like input, infer the correct vega typecode @@ -213,9 +213,8 @@ def infer_vegalite_type( Parameters ---------- - data: Numpy array or Pandas Series + data: object """ - # Otherwise, infer based on the dtype of the input typ = infer_dtype(data) if typ in [ @@ -226,10 +225,8 @@ def infer_vegalite_type( "complex", ]: return "quantitative" - # Can ignore error that np.ndarray has no attribute cat as in this case - # it should always be a pd.DataFrame anyway - elif typ == "categorical" and data.cat.ordered: # type: ignore[union-attr] - return ("ordinal", data.cat.categories.tolist()) # type: ignore[union-attr] + elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered: + return ("ordinal", data.cat.categories.tolist()) elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]: return "nominal" elif typ in [