From 91a1de17f97d0b1de20724f3cfa771f07741dcc7 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 2 Jul 2024 19:09:20 +0100 Subject: [PATCH] feat(typing): adds `Map` alias for `Mapping[str, Any]` and uses it in `Chart.encode` (#3458) This allows a `TypedDict` to be used, where `mypy` previously would have required exactly a `dict`. The name is not used anywhere currently in `altair` and has the benefit of being short like `dict`, but more permissive. I have only added this to `Chart.encode` as I know this does not require the mutability of `dict`, which I cannot confidently say for elsewhere in `altair`. Fixes https://github.com/vega/altair/pull/3427#discussion_r1662542866 --- altair/vegalite/v5/schema/_typing.py | 4 +- altair/vegalite/v5/schema/channels.py | 82 +++++++++++++-------------- tools/generate_schema_wrapper.py | 10 +++- tools/schemapi/utils.py | 14 +++-- 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/altair/vegalite/v5/schema/_typing.py b/altair/vegalite/v5/schema/_typing.py index 67d12b6e5..1c238c827 100644 --- a/altair/vegalite/v5/schema/_typing.py +++ b/altair/vegalite/v5/schema/_typing.py @@ -4,10 +4,12 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal, Mapping from typing_extensions import TypeAlias +Map: TypeAlias = Mapping[str, Any] + AggregateOp_T: TypeAlias = Literal[ "argmax", "argmin", diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 8a92ce045..da3365613 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -34207,65 +34207,65 @@ class _EncodingMixin: def encode( self, *args: Any, - angle: Optional[str | Angle | dict | AngleDatum | AngleValue] = Undefined, - color: Optional[str | Color | dict | ColorDatum | ColorValue] = Undefined, - column: Optional[str | Column | dict] = Undefined, - description: Optional[str | Description | dict | DescriptionValue] = Undefined, - detail: Optional[str | Detail | dict | list] = Undefined, - facet: Optional[str | Facet | dict] = Undefined, - fill: Optional[str | Fill | dict | FillDatum | FillValue] = Undefined, + angle: Optional[str | Angle | Map | AngleDatum | AngleValue] = Undefined, + color: Optional[str | Color | Map | ColorDatum | ColorValue] = Undefined, + column: Optional[str | Column | Map] = Undefined, + description: Optional[str | Description | Map | DescriptionValue] = Undefined, + detail: Optional[str | Detail | Map | list] = Undefined, + facet: Optional[str | Facet | Map] = Undefined, + fill: Optional[str | Fill | Map | FillDatum | FillValue] = Undefined, fillOpacity: Optional[ - str | FillOpacity | dict | FillOpacityDatum | FillOpacityValue + str | FillOpacity | Map | FillOpacityDatum | FillOpacityValue ] = Undefined, - href: Optional[str | Href | dict | HrefValue] = Undefined, - key: Optional[str | Key | dict] = Undefined, - latitude: Optional[str | Latitude | dict | LatitudeDatum] = Undefined, + href: Optional[str | Href | Map | HrefValue] = Undefined, + key: Optional[str | Key | Map] = Undefined, + latitude: Optional[str | Latitude | Map | LatitudeDatum] = Undefined, latitude2: Optional[ - str | Latitude2 | dict | Latitude2Datum | Latitude2Value + str | Latitude2 | Map | Latitude2Datum | Latitude2Value ] = Undefined, - longitude: Optional[str | Longitude | dict | LongitudeDatum] = Undefined, + longitude: Optional[str | Longitude | Map | LongitudeDatum] = Undefined, longitude2: Optional[ - str | Longitude2 | dict | Longitude2Datum | Longitude2Value + str | Longitude2 | Map | Longitude2Datum | Longitude2Value ] = Undefined, opacity: Optional[ - str | Opacity | dict | OpacityDatum | OpacityValue + str | Opacity | Map | OpacityDatum | OpacityValue ] = Undefined, - order: Optional[str | Order | dict | list | OrderValue] = Undefined, - radius: Optional[str | Radius | dict | RadiusDatum | RadiusValue] = Undefined, + order: Optional[str | Order | Map | list | OrderValue] = Undefined, + radius: Optional[str | Radius | Map | RadiusDatum | RadiusValue] = Undefined, radius2: Optional[ - str | Radius2 | dict | Radius2Datum | Radius2Value + str | Radius2 | Map | Radius2Datum | Radius2Value ] = Undefined, - row: Optional[str | Row | dict] = Undefined, - shape: Optional[str | Shape | dict | ShapeDatum | ShapeValue] = Undefined, - size: Optional[str | Size | dict | SizeDatum | SizeValue] = Undefined, - stroke: Optional[str | Stroke | dict | StrokeDatum | StrokeValue] = Undefined, + row: Optional[str | Row | Map] = Undefined, + shape: Optional[str | Shape | Map | ShapeDatum | ShapeValue] = Undefined, + size: Optional[str | Size | Map | SizeDatum | SizeValue] = Undefined, + stroke: Optional[str | Stroke | Map | StrokeDatum | StrokeValue] = Undefined, strokeDash: Optional[ - str | StrokeDash | dict | StrokeDashDatum | StrokeDashValue + str | StrokeDash | Map | StrokeDashDatum | StrokeDashValue ] = Undefined, strokeOpacity: Optional[ - str | StrokeOpacity | dict | StrokeOpacityDatum | StrokeOpacityValue + str | StrokeOpacity | Map | StrokeOpacityDatum | StrokeOpacityValue ] = Undefined, strokeWidth: Optional[ - str | StrokeWidth | dict | StrokeWidthDatum | StrokeWidthValue - ] = Undefined, - text: Optional[str | Text | dict | TextDatum | TextValue] = Undefined, - theta: Optional[str | Theta | dict | ThetaDatum | ThetaValue] = Undefined, - theta2: Optional[str | Theta2 | dict | Theta2Datum | Theta2Value] = Undefined, - tooltip: Optional[str | Tooltip | dict | list | TooltipValue] = Undefined, - url: Optional[str | Url | dict | UrlValue] = Undefined, - x: Optional[str | X | dict | XDatum | XValue] = Undefined, - x2: Optional[str | X2 | dict | X2Datum | X2Value] = Undefined, - xError: Optional[str | XError | dict | XErrorValue] = Undefined, - xError2: Optional[str | XError2 | dict | XError2Value] = Undefined, + str | StrokeWidth | Map | StrokeWidthDatum | StrokeWidthValue + ] = Undefined, + text: Optional[str | Text | Map | TextDatum | TextValue] = Undefined, + theta: Optional[str | Theta | Map | ThetaDatum | ThetaValue] = Undefined, + theta2: Optional[str | Theta2 | Map | Theta2Datum | Theta2Value] = Undefined, + tooltip: Optional[str | Tooltip | Map | list | TooltipValue] = Undefined, + url: Optional[str | Url | Map | UrlValue] = Undefined, + x: Optional[str | X | Map | XDatum | XValue] = Undefined, + x2: Optional[str | X2 | Map | X2Datum | X2Value] = Undefined, + xError: Optional[str | XError | Map | XErrorValue] = Undefined, + xError2: Optional[str | XError2 | Map | XError2Value] = Undefined, xOffset: Optional[ - str | XOffset | dict | XOffsetDatum | XOffsetValue + str | XOffset | Map | XOffsetDatum | XOffsetValue ] = Undefined, - y: Optional[str | Y | dict | YDatum | YValue] = Undefined, - y2: Optional[str | Y2 | dict | Y2Datum | Y2Value] = Undefined, - yError: Optional[str | YError | dict | YErrorValue] = Undefined, - yError2: Optional[str | YError2 | dict | YError2Value] = Undefined, + y: Optional[str | Y | Map | YDatum | YValue] = Undefined, + y2: Optional[str | Y2 | Map | Y2Datum | Y2Value] = Undefined, + yError: Optional[str | YError | Map | YErrorValue] = Undefined, + yError2: Optional[str | YError2 | Map | YError2Value] = Undefined, yOffset: Optional[ - str | YOffset | dict | YOffsetDatum | YOffsetValue + str | YOffset | Map | YOffsetDatum | YOffsetValue ] = Undefined, ) -> Self: """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index f1a73670f..f46096960 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -232,6 +232,12 @@ def encode({encode_method_args}) -> Self: return copy ''' +# These types should support annotations in generated code, +# but are not derived from the schema itself. +EXTRA_ALIASES: Final = """ +Map: TypeAlias = Mapping[str, Any] +""" + class SchemaGenerator(codegen.SchemaGenerator): schema_class_template = textwrap.dedent( @@ -816,7 +822,7 @@ def vegalite_main(skip_download: bool = False) -> None: f"Tracer cache collected {TypeAliasTracer.n_entries!r} entries." ) print(msg) - TypeAliasTracer.write_module(fp_typing, header=HEADER) + TypeAliasTracer.write_module(fp_typing, header=HEADER, extra_aliases=EXTRA_ALIASES) # Write the pre-generated modules for fp, contents in files.items(): print(f"Writing\n {schemafile!s}\n ->{fp!s}") @@ -844,7 +850,7 @@ def _create_encode_signature( # the dictionary representation of an encoding channel class. See # discussions in https://github.com/vega/altair/pull/3208 # for more background. - union_types = ["str", field_class_name, "dict"] + union_types = ["str", field_class_name, "Map"] docstring_union_types = ["str", rst_syntax_for_class(field_class_name), "Dict"] if info.supports_arrays: # We could be more specific about what types are accepted in the list diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index a9542b15e..4475282a8 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -71,7 +71,7 @@ def __init__( self.aliases: list[tuple[str, str]] = [] self._imports: Sequence[str] = ( "from __future__ import annotations\n", - "from typing import Literal", + "from typing import Literal, Mapping, Any", "from typing_extensions import TypeAlias", ) self._cmd_check: list[str] = ["--fix"] @@ -113,7 +113,11 @@ def generate_aliases(self) -> Iterator[str]: yield f"{name}: TypeAlias = {statement}" def write_module( - self, fp: Path, *extra_imports: str, header: LiteralString + self, + fp: Path, + *extra_imports: str, + header: LiteralString, + extra_aliases: LiteralString, ) -> None: """Write all collected `TypeAlias`'s to `fp`. @@ -125,13 +129,15 @@ def write_module( Follows `self._imports` block. header `tools.generate_schema_wrapper.HEADER`. + extra_aliases + `tools.generate_schema_wrapper.EXTRA_ALIASES`. """ ruff_format = ["ruff", "format", fp] if self._cmd_format: ruff_format.extend(self._cmd_format) commands = (["ruff", "check", fp, *self._cmd_check], ruff_format) - imports = (header, "\n", *self._imports, *extra_imports, "\n\n") - it = chain(imports, self.generate_aliases()) + static = (header, "\n", *self._imports, *extra_imports, "\n\n", extra_aliases) + it = chain(static, self.generate_aliases()) fp.write_text("\n".join(it), encoding="utf-8") for cmd in commands: r = subprocess.run(cmd, check=True)