Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hints: Parts of folders "vegalite", "v5", and "utils" #2976

Merged
merged 21 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 57 additions & 46 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
Utility routines
"""
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from copy import deepcopy
import json
import itertools
import re
import sys
import traceback
import warnings
from typing import Callable, TypeVar, Any
from typing import Callable, TypeVar, Any, Union, Dict, Optional, Tuple, Sequence, Type
from types import ModuleType

import jsonschema
import pandas as pd
Expand All @@ -23,9 +24,9 @@
from typing_extensions import ParamSpec

if sys.version_info >= (3, 8):
from typing import Protocol
from typing import Literal, Protocol
else:
from typing_extensions import Protocol
from typing_extensions import Literal, Protocol

try:
from pandas.api.types import infer_dtype as _infer_dtype
Expand All @@ -42,7 +43,7 @@ def __dataframe__(self, *args, **kwargs):
...


def infer_dtype(value):
def infer_dtype(value: object) -> str:
"""Infer the dtype of the value.

This is a compatibility function for pandas infer_dtype,
Expand All @@ -53,10 +54,10 @@ def infer_dtype(value):
_infer_dtype([1], skipna=False)
except TypeError:
# pandas < 0.21.0 don't support skipna keyword
infer_dtype._supports_skipna = False
infer_dtype._supports_skipna = False # type: ignore[attr-defined]
else:
infer_dtype._supports_skipna = True
if infer_dtype._supports_skipna:
infer_dtype._supports_skipna = True # type: ignore[attr-defined]
if infer_dtype._supports_skipna: # type: ignore[attr-defined]
return _infer_dtype(value, skipna=False)
else:
return _infer_dtype(value)
Expand Down Expand Up @@ -200,16 +201,20 @@ def infer_dtype(value):
]


def infer_vegalite_type(data):
_InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
data: object,
) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]:
"""
From an array-like input, infer the correct vega typecode
('ordinal', 'nominal', 'quantitative', or 'temporal')

Parameters
----------
data: Numpy array or Pandas Series
data: object
"""
# Otherwise, infer based on the dtype of the input
typ = infer_dtype(data)

if typ in [
Expand All @@ -220,7 +225,7 @@ def infer_vegalite_type(data):
"complex",
]:
return "quantitative"
elif typ == "categorical" and data.cat.ordered:
elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
return ("ordinal", data.cat.categories.tolist())
elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
return "nominal"
Expand All @@ -243,7 +248,7 @@ def infer_vegalite_type(data):
return "nominal"


def merge_props_geom(feat):
def merge_props_geom(feat: dict) -> dict:
"""
Merge properties with geometry
* Overwrites 'type' and 'geometry' entries if existing
Expand All @@ -261,7 +266,7 @@ def merge_props_geom(feat):
return props_geom


def sanitize_geo_interface(geo):
def sanitize_geo_interface(geo: MutableMapping) -> dict:
"""Sanitize a geo_interface to prepare it for serialization.

* Make a copy
Expand All @@ -278,23 +283,23 @@ def sanitize_geo_interface(geo):
geo[key] = geo[key].tolist()

# convert (nested) tuples to lists
geo = json.loads(json.dumps(geo))
geo_dct: dict = json.loads(json.dumps(geo))

# sanitize features
if geo["type"] == "FeatureCollection":
geo = geo["features"]
if len(geo) > 0:
for idx, feat in enumerate(geo):
geo[idx] = merge_props_geom(feat)
elif geo["type"] == "Feature":
geo = merge_props_geom(geo)
if geo_dct["type"] == "FeatureCollection":
geo_dct = geo_dct["features"]
if len(geo_dct) > 0:
for idx, feat in enumerate(geo_dct):
geo_dct[idx] = merge_props_geom(feat)
elif geo_dct["type"] == "Feature":
geo_dct = merge_props_geom(geo_dct)
else:
geo = {"type": "Feature", "geometry": geo}
geo_dct = {"type": "Feature", "geometry": geo_dct}

return geo
return geo_dct


def sanitize_dataframe(df): # noqa: C901
def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: # noqa: C901
"""Sanitize a DataFrame to prepare it for serialization.

* Make a copy
Expand Down Expand Up @@ -433,13 +438,13 @@ def sanitize_arrow_table(pa_table):


def parse_shorthand(
shorthand,
data=None,
parse_aggregates=True,
parse_window_ops=False,
parse_timeunits=True,
parse_types=True,
):
shorthand: Union[Dict[str, Any], str],
data: Optional[pd.DataFrame] = None,
parse_aggregates: bool = True,
parse_window_ops: bool = False,
parse_timeunits: bool = True,
parse_types: bool = True,
) -> Dict[str, Any]:
"""General tool to parse shorthand values

These are of the form:
Expand Down Expand Up @@ -554,7 +559,9 @@ def parse_shorthand(
attrs = shorthand
else:
attrs = next(
exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
exp.match(shorthand).groupdict() # type: ignore[union-attr]
for exp in regexps
if exp.match(shorthand) is not None
)

# Handle short form of the type expression
Expand Down Expand Up @@ -629,21 +636,23 @@ def decorate(f: Callable[..., _V]) -> Callable[_P, _V]:
return decorate


def update_nested(original, update, copy=False):
def update_nested(
original: MutableMapping, update: Mapping, copy: bool = False
) -> MutableMapping:
"""Update nested dictionaries

Parameters
----------
original : dict
original : MutableMapping
the original (nested) dictionary, which will be updated in-place
update : dict
update : Mapping
the nested dictionary of updates
copy : bool, default False
if True, then copy the original dictionary rather than modifying it

Returns
-------
original : dict
original : MutableMapping
a reference to the (modified) original dict

Examples
Expand All @@ -660,7 +669,7 @@ def update_nested(original, update, copy=False):
for key, val in update.items():
if isinstance(val, Mapping):
orig_val = original.get(key, {})
if isinstance(orig_val, Mapping):
if isinstance(orig_val, MutableMapping):
original[key] = update_nested(orig_val, val)
else:
original[key] = val
Expand All @@ -669,7 +678,7 @@ def update_nested(original, update, copy=False):
return original


def display_traceback(in_ipython=True):
def display_traceback(in_ipython: bool = True):
exc_info = sys.exc_info()

if in_ipython:
Expand All @@ -685,16 +694,16 @@ def display_traceback(in_ipython=True):
traceback.print_exception(*exc_info)


def infer_encoding_types(args, kwargs, channels):
def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType):
"""Infer typed keyword arguments for args and kwargs

Parameters
----------
args : tuple
List of function args
kwargs : dict
args : Sequence
Sequence of function args
kwargs : MutableMapping
Dict of function kwargs
channels : module
channels : ModuleType
The module containing all altair encoding channel classes.

Returns
Expand All @@ -709,8 +718,10 @@ def infer_encoding_types(args, kwargs, channels):
channel_objs = (
c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
)
channel_to_name = {c: c._encoding_name for c in channel_objs}
name_to_channel = {}
channel_to_name: Dict[Type[SchemaBase], str] = {
c: c._encoding_name for c in channel_objs
}
name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {}
for chan, name in channel_to_name.items():
chans = name_to_channel.setdefault(name, {})
if chan.__name__.endswith("Datum"):
Expand Down
Loading