Skip to content

Commit

Permalink
Type hints: Parts of folders "vegalite", "v5", and "utils" (#2976)
Browse files Browse the repository at this point in the history
* Type hint vegalite/v5/theme.py

* Type hint display modules

* Type hint data modules and some related files

* Type hint core.py

* Type hint html.py

* Format code

* Type hint default_data_transformer

* Improve type hints on data transformers

* Ignore mypy missing-import errors for nbformat and ipykernel. It is unclear why mypy suddenly started complaining about these imports.

* Use Final for constants

* Mark as private some newly added type aliases/protocols that we may not yet want users to rely on. This gives us more flexibility to change them in the future.

* Remove unused 'type: ignore' statements and configure mypy to raise an error when it detects such a redundant statement.

* Replace _SupportsDataFrame with _DataFrameLike

* Format code

* Rename two TypedDicts to make it clearer that they represent the return types of functions.

* Make InferredVegaLiteType private

* Type hint 'data' in infer_vegalite_type as 'object', as this is the type hint used by pandas for its infer_dtype function.

---------

Co-authored-by: Jon Mease <[email protected]>
  • Loading branch information
binste and jonmmease committed Jun 18, 2023
1 parent 564b472 commit f13f8f1
Show file tree
Hide file tree
Showing 14 changed files with 310 additions and 174 deletions.
103 changes: 57 additions & 46 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
Utility routines
"""
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from copy import deepcopy
import json
import itertools
import re
import sys
import traceback
import warnings
from typing import Callable, TypeVar, Any
from typing import Callable, TypeVar, Any, Union, Dict, Optional, Tuple, Sequence, Type
from types import ModuleType

import jsonschema
import pandas as pd
Expand All @@ -23,9 +24,9 @@
from typing_extensions import ParamSpec

if sys.version_info >= (3, 8):
from typing import Protocol
from typing import Literal, Protocol
else:
from typing_extensions import Protocol
from typing_extensions import Literal, Protocol

try:
from pandas.api.types import infer_dtype as _infer_dtype
Expand All @@ -42,7 +43,7 @@ def __dataframe__(self, *args, **kwargs):
...


def infer_dtype(value):
def infer_dtype(value: object) -> str:
"""Infer the dtype of the value.
This is a compatibility function for pandas infer_dtype,
Expand All @@ -53,10 +54,10 @@ def infer_dtype(value):
_infer_dtype([1], skipna=False)
except TypeError:
# pandas < 0.21.0 don't support skipna keyword
infer_dtype._supports_skipna = False
infer_dtype._supports_skipna = False # type: ignore[attr-defined]
else:
infer_dtype._supports_skipna = True
if infer_dtype._supports_skipna:
infer_dtype._supports_skipna = True # type: ignore[attr-defined]
if infer_dtype._supports_skipna: # type: ignore[attr-defined]
return _infer_dtype(value, skipna=False)
else:
return _infer_dtype(value)
Expand Down Expand Up @@ -200,16 +201,20 @@ def infer_dtype(value):
]


def infer_vegalite_type(data):
_InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
data: object,
) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
"""
From an array-like input, infer the correct vega typecode
('ordinal', 'nominal', 'quantitative', or 'temporal')
Parameters
----------
data: Numpy array or Pandas Series
data: object
"""
# Otherwise, infer based on the dtype of the input
typ = infer_dtype(data)

if typ in [
Expand All @@ -220,7 +225,7 @@ def infer_vegalite_type(data):
"complex",
]:
return "quantitative"
elif typ == "categorical" and data.cat.ordered:
elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
return ("ordinal", data.cat.categories.tolist())
elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
return "nominal"
Expand All @@ -243,7 +248,7 @@ def infer_vegalite_type(data):
return "nominal"


def merge_props_geom(feat):
def merge_props_geom(feat: dict) -> dict:
"""
Merge properties with geometry
* Overwrites 'type' and 'geometry' entries if existing
Expand All @@ -261,7 +266,7 @@ def merge_props_geom(feat):
return props_geom


def sanitize_geo_interface(geo):
def sanitize_geo_interface(geo: MutableMapping) -> dict:
"""Santize a geo_interface to prepare it for serialization.
* Make a copy
Expand All @@ -278,23 +283,23 @@ def sanitize_geo_interface(geo):
geo[key] = geo[key].tolist()

# convert (nested) tuples to lists
geo = json.loads(json.dumps(geo))
geo_dct: dict = json.loads(json.dumps(geo))

# sanitize features
if geo["type"] == "FeatureCollection":
geo = geo["features"]
if len(geo) > 0:
for idx, feat in enumerate(geo):
geo[idx] = merge_props_geom(feat)
elif geo["type"] == "Feature":
geo = merge_props_geom(geo)
if geo_dct["type"] == "FeatureCollection":
geo_dct = geo_dct["features"]
if len(geo_dct) > 0:
for idx, feat in enumerate(geo_dct):
geo_dct[idx] = merge_props_geom(feat)
elif geo_dct["type"] == "Feature":
geo_dct = merge_props_geom(geo_dct)
else:
geo = {"type": "Feature", "geometry": geo}
geo_dct = {"type": "Feature", "geometry": geo_dct}

return geo
return geo_dct


def sanitize_dataframe(df): # noqa: C901
def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: # noqa: C901
"""Sanitize a DataFrame to prepare it for serialization.
* Make a copy
Expand Down Expand Up @@ -433,13 +438,13 @@ def sanitize_arrow_table(pa_table):


def parse_shorthand(
shorthand,
data=None,
parse_aggregates=True,
parse_window_ops=False,
parse_timeunits=True,
parse_types=True,
):
shorthand: Union[Dict[str, Any], str],
data: Optional[pd.DataFrame] = None,
parse_aggregates: bool = True,
parse_window_ops: bool = False,
parse_timeunits: bool = True,
parse_types: bool = True,
) -> Dict[str, Any]:
"""General tool to parse shorthand values
These are of the form:
Expand Down Expand Up @@ -554,7 +559,9 @@ def parse_shorthand(
attrs = shorthand
else:
attrs = next(
exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
exp.match(shorthand).groupdict() # type: ignore[union-attr]
for exp in regexps
if exp.match(shorthand) is not None
)

# Handle short form of the type expression
Expand Down Expand Up @@ -629,21 +636,23 @@ def decorate(f: Callable[..., _V]) -> Callable[_P, _V]:
return decorate


def update_nested(original, update, copy=False):
def update_nested(
original: MutableMapping, update: Mapping, copy: bool = False
) -> MutableMapping:
"""Update nested dictionaries
Parameters
----------
original : dict
original : MutableMapping
the original (nested) dictionary, which will be updated in-place
update : dict
update : Mapping
the nested dictionary of updates
copy : bool, default False
if True, then copy the original dictionary rather than modifying it
Returns
-------
original : dict
original : MutableMapping
a reference to the (modified) original dict
Examples
Expand All @@ -660,7 +669,7 @@ def update_nested(original, update, copy=False):
for key, val in update.items():
if isinstance(val, Mapping):
orig_val = original.get(key, {})
if isinstance(orig_val, Mapping):
if isinstance(orig_val, MutableMapping):
original[key] = update_nested(orig_val, val)
else:
original[key] = val
Expand All @@ -669,7 +678,7 @@ def update_nested(original, update, copy=False):
return original


def display_traceback(in_ipython=True):
def display_traceback(in_ipython: bool = True):
exc_info = sys.exc_info()

if in_ipython:
Expand All @@ -685,16 +694,16 @@ def display_traceback(in_ipython=True):
traceback.print_exception(*exc_info)


def infer_encoding_types(args, kwargs, channels):
def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType):
"""Infer typed keyword arguments for args and kwargs
Parameters
----------
args : tuple
List of function args
kwargs : dict
args : Sequence
Sequence of function args
kwargs : MutableMapping
Dict of function kwargs
channels : module
channels : ModuleType
The module containing all altair encoding channel classes.
Returns
Expand All @@ -709,8 +718,10 @@ def infer_encoding_types(args, kwargs, channels):
channel_objs = (
c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
)
channel_to_name = {c: c._encoding_name for c in channel_objs}
name_to_channel = {}
channel_to_name: Dict[Type[SchemaBase], str] = {
c: c._encoding_name for c in channel_objs
}
name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {}
for chan, name in channel_to_name.items():
chans = name_to_channel.setdefault(name, {})
if chan.__name__.endswith("Datum"):
Expand Down
Loading

0 comments on commit f13f8f1

Please sign in to comment.