diff --git a/altair/__init__.py b/altair/__init__.py index d6c03f48a..9a6496d5e 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -50,7 +50,6 @@ "ChainedWhen", "Chart", "ChartDataType", - "ChartType", "Color", "ColorDatum", "ColorDef", @@ -301,7 +300,6 @@ "Opacity", "OpacityDatum", "OpacityValue", - "Optional", "Order", "OrderFieldDef", "OrderOnlyDef", @@ -611,7 +609,6 @@ "expr", "graticule", "hconcat", - "is_chart_type", "jupyter", "layer", "limit_rows", @@ -634,6 +631,7 @@ "to_json", "to_values", "topo_feature", + "typing", "utils", "v5", "value", @@ -653,7 +651,8 @@ def __dir__(): from altair.vegalite.v5.schema.core import Dict from altair.jupyter import JupyterChart from altair.expr import expr -from altair.utils import AltairDeprecationWarning, parse_shorthand, Optional, Undefined +from altair.utils import AltairDeprecationWarning, parse_shorthand, Undefined +from altair import typing def load_ipython_extension(ipython): diff --git a/altair/typing.py b/altair/typing.py new file mode 100644 index 000000000..cd8cb1489 --- /dev/null +++ b/altair/typing.py @@ -0,0 +1,96 @@ +"""Public types to ease integrating with `altair`.""" + +from __future__ import annotations + +__all__ = [ + "ChannelAngle", + "ChannelColor", + "ChannelColumn", + "ChannelDescription", + "ChannelDetail", + "ChannelFacet", + "ChannelFill", + "ChannelFillOpacity", + "ChannelHref", + "ChannelKey", + "ChannelLatitude", + "ChannelLatitude2", + "ChannelLongitude", + "ChannelLongitude2", + "ChannelOpacity", + "ChannelOrder", + "ChannelRadius", + "ChannelRadius2", + "ChannelRow", + "ChannelShape", + "ChannelSize", + "ChannelStroke", + "ChannelStrokeDash", + "ChannelStrokeOpacity", + "ChannelStrokeWidth", + "ChannelText", + "ChannelTheta", + "ChannelTheta2", + "ChannelTooltip", + "ChannelUrl", + "ChannelX", + "ChannelX2", + "ChannelXError", + "ChannelXError2", + "ChannelXOffset", + "ChannelY", + "ChannelY2", + "ChannelYError", + "ChannelYError2", + "ChannelYOffset", + "ChartType", + "EncodeKwds", + "Optional", + "is_chart_type", +] + +from altair.utils.schemapi import Optional +from altair.vegalite.v5.api import ChartType, is_chart_type +from altair.vegalite.v5.schema.channels import ( + ChannelAngle, + ChannelColor, + ChannelColumn, + ChannelDescription, + ChannelDetail, + ChannelFacet, + ChannelFill, + ChannelFillOpacity, + ChannelHref, + ChannelKey, + ChannelLatitude, + ChannelLatitude2, + ChannelLongitude, + ChannelLongitude2, + ChannelOpacity, + ChannelOrder, + ChannelRadius, + ChannelRadius2, + ChannelRow, + ChannelShape, + ChannelSize, + ChannelStroke, + ChannelStrokeDash, + ChannelStrokeOpacity, + ChannelStrokeWidth, + ChannelText, + ChannelTheta, + ChannelTheta2, + ChannelTooltip, + ChannelUrl, + ChannelX, + ChannelX2, + ChannelXError, + ChannelXError2, + ChannelXOffset, + ChannelY, + ChannelY2, + ChannelYError, + ChannelYError2, + ChannelYOffset, + EncodeKwds, +) diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 12fe11bf5..3839a13d2 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -31,8 +31,8 @@ from altair.utils.schemapi import Undefined if TYPE_CHECKING: + from altair.typing import ChartType from altair.utils.core import DataFrameLike - from altair.vegalite.v5.api import ChartType Scope: TypeAlias = Tuple[int, ...] FacetMapping: TypeAlias = Dict[Tuple[str, Scope], Tuple[str, Scope]] @@ -452,7 +452,7 @@ def get_facet_mapping(group: dict[str, Any], scope: Scope = ()) -> FacetMapping: group, facet_data, scope ) if definition_scope is not None: - facet_mapping[(facet_name, group_scope)] = ( + facet_mapping[facet_name, group_scope] = ( facet_data, definition_scope, ) diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index de196025b..e8884ccd2 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -44,7 +44,7 @@ from referencing import Registry - from altair import ChartType + from altair.typing import ChartType if sys.version_info >= (3, 13): from typing import TypeIs @@ -777,7 +777,7 @@ def __repr__(self) -> str: The parameters ``short``, ``long`` accept the same range of types:: # ruff: noqa: UP006, UP007 - from altair import Optional + from altair.typing import Optional def func_1( short: Optional[str | bool | float | dict[str, Any] | SchemaBase] = Undefined, @@ -786,10 +786,12 @@ def func_1( ] = Undefined, ): ... -This is distinct from `typing.Optional `__ as ``altair.Optional`` treats ``None`` like any other type:: +This is distinct from `typing.Optional `__. + +``altair.typing.Optional`` treats ``None`` like any other type:: # ruff: noqa: UP006, UP007 - from altair import Optional + from altair.typing import Optional def func_2( short: Optional[str | float | dict[str, Any] | None | SchemaBase] = Undefined, diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 5a754a9ad..0b3776541 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -145,7 +145,6 @@ "ChainedWhen", "Chart", "ChartDataType", - "ChartType", "ConcatChart", "DataType", "FacetChart", @@ -174,7 +173,6 @@ "condition", "graticule", "hconcat", - "is_chart_type", "layer", "mixins", "param", diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index d41a2c96d..34daf00ab 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -11,7 +11,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Sequence, overload +from typing import TYPE_CHECKING, Any, Literal, Sequence, TypedDict, Union, overload +from typing_extensions import TypeAlias from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe @@ -20,15 +21,14 @@ from altair.utils.schemapi import Undefined, with_property_setters from . import core +from ._typing import * # noqa: F403 # ruff: noqa: F405 if TYPE_CHECKING: from typing_extensions import Self from altair import Parameter, SchemaBase - from altair.utils.schemapi import Optional - - from ._typing import * # noqa: F403 + from altair.typing import Optional __all__ = [ @@ -30897,6 +30897,58 @@ def __init__(self, value, **kwds): super().__init__(value=value, **kwds) +ChannelAngle: TypeAlias = Union[str, Angle, Map, AngleDatum, AngleValue] +ChannelColor: TypeAlias = Union[str, Color, Map, ColorDatum, ColorValue] +ChannelColumn: TypeAlias = Union[str, Column, Map] +ChannelDescription: TypeAlias = Union[str, Description, Map, DescriptionValue] +ChannelDetail: TypeAlias = OneOrSeq[Union[str, Detail, Map]] +ChannelFacet: TypeAlias = Union[str, Facet, Map] +ChannelFill: TypeAlias = Union[str, Fill, Map, FillDatum, FillValue] +ChannelFillOpacity: TypeAlias = Union[ + str, FillOpacity, Map, FillOpacityDatum, FillOpacityValue +] +ChannelHref: TypeAlias = Union[str, Href, Map, HrefValue] +ChannelKey: TypeAlias = Union[str, Key, Map] +ChannelLatitude: TypeAlias = Union[str, Latitude, Map, LatitudeDatum] +ChannelLatitude2: TypeAlias = Union[str, Latitude2, Map, Latitude2Datum, Latitude2Value] +ChannelLongitude: TypeAlias = Union[str, Longitude, Map, LongitudeDatum] +ChannelLongitude2: TypeAlias = Union[ + str, Longitude2, Map, Longitude2Datum, Longitude2Value +] +ChannelOpacity: TypeAlias = Union[str, Opacity, Map, OpacityDatum, OpacityValue] +ChannelOrder: TypeAlias = OneOrSeq[Union[str, Order, Map, OrderValue]] +ChannelRadius: TypeAlias = Union[str, Radius, Map, RadiusDatum, RadiusValue] +ChannelRadius2: TypeAlias = Union[str, Radius2, Map, Radius2Datum, Radius2Value] +ChannelRow: TypeAlias = Union[str, Row, Map] +ChannelShape: TypeAlias = Union[str, Shape, Map, ShapeDatum, ShapeValue] +ChannelSize: TypeAlias = Union[str, Size, Map, SizeDatum, SizeValue] +ChannelStroke: TypeAlias = Union[str, Stroke, Map, StrokeDatum, StrokeValue] +ChannelStrokeDash: TypeAlias = Union[ + str, StrokeDash, Map, StrokeDashDatum, StrokeDashValue +] +ChannelStrokeOpacity: TypeAlias = Union[ + str, StrokeOpacity, Map, StrokeOpacityDatum, StrokeOpacityValue +] +ChannelStrokeWidth: TypeAlias = Union[ + str, StrokeWidth, Map, StrokeWidthDatum, StrokeWidthValue +] +ChannelText: TypeAlias = Union[str, Text, Map, TextDatum, TextValue] +ChannelTheta: TypeAlias = Union[str, Theta, Map, ThetaDatum, ThetaValue] +ChannelTheta2: TypeAlias = Union[str, Theta2, Map, Theta2Datum, Theta2Value] +ChannelTooltip: TypeAlias = OneOrSeq[Union[str, Tooltip, Map, TooltipValue]] +ChannelUrl: TypeAlias = Union[str, Url, Map, UrlValue] +ChannelX: TypeAlias = Union[str, X, Map, XDatum, XValue] +ChannelX2: TypeAlias = Union[str, X2, Map, X2Datum, X2Value] +ChannelXError: TypeAlias = Union[str, XError, Map, XErrorValue] +ChannelXError2: TypeAlias = Union[str, XError2, Map, XError2Value] +ChannelXOffset: TypeAlias = Union[str, XOffset, Map, XOffsetDatum, XOffsetValue] +ChannelY: TypeAlias = Union[str, Y, Map, YDatum, YValue] +ChannelY2: TypeAlias = Union[str, Y2, Map, Y2Datum, Y2Value] +ChannelYError: TypeAlias = Union[str, YError, Map, YErrorValue] +ChannelYError2: TypeAlias = Union[str, YError2, Map, YError2Value] +ChannelYOffset: TypeAlias = Union[str, YOffset, Map, YOffsetDatum, YOffsetValue] + + class _EncodingMixin: def encode( self, @@ -30905,7 +30957,7 @@ def encode( color: Optional[str | Color | Map | ColorDatum | ColorValue] = Undefined, column: Optional[str | Column | Map] = Undefined, description: Optional[str | Description | Map | DescriptionValue] = Undefined, - detail: Optional[str | Detail | Map | list] = Undefined, + detail: Optional[OneOrSeq[str | Detail | Map]] = Undefined, facet: Optional[str | Facet | Map] = Undefined, fill: Optional[str | Fill | Map | FillDatum | FillValue] = Undefined, fillOpacity: Optional[ @@ -30924,7 +30976,7 @@ def encode( opacity: Optional[ str | Opacity | Map | OpacityDatum | OpacityValue ] = Undefined, - order: Optional[str | Order | Map | list | OrderValue] = Undefined, + order: Optional[OneOrSeq[str | Order | Map | OrderValue]] = Undefined, radius: Optional[str | Radius | Map | RadiusDatum | RadiusValue] = Undefined, radius2: Optional[ str | Radius2 | Map | Radius2Datum | Radius2Value @@ -30945,7 +30997,7 @@ def encode( text: Optional[str | Text | Map | TextDatum | TextValue] = Undefined, theta: Optional[str | Theta | Map | ThetaDatum | ThetaValue] = Undefined, theta2: Optional[str | Theta2 | Map | Theta2Datum | Theta2Value] = Undefined, - tooltip: Optional[str | Tooltip | Map | list | TooltipValue] = Undefined, + tooltip: Optional[OneOrSeq[str | Tooltip | Map | TooltipValue]] = Undefined, url: Optional[str | Url | Map | UrlValue] = Undefined, x: Optional[str | X | Map | XDatum | XValue] = Undefined, x2: Optional[str | X2 | Map | X2Datum | X2Value] = Undefined, @@ -31189,3 +31241,245 @@ def encode( encoding.update(kwargs) copy.encoding = core.FacetedEncoding(**encoding) return copy + + +class EncodeKwds(TypedDict, total=False): + """ + Encoding channels map properties of the data to visual properties of the chart. + + Parameters + ---------- + angle + Rotation angle of point and text marks. + color + Color of the marks - either fill or stroke color based on the ``filled`` property + of mark definition. By default, ``color`` represents fill color for ``"area"``, + ``"bar"``, ``"tick"``, ``"text"``, ``"trail"``, ``"circle"``, and ``"square"`` / + stroke color for ``"line"`` and ``"point"``. + + **Default value:** If undefined, the default color depends on `mark config + `__'s ``color`` + property. + + *Note:* 1) For fine-grained control over both fill and stroke colors of the marks, + please use the ``fill`` and ``stroke`` channels. The ``fill`` or ``stroke`` + encodings have higher precedence than ``color``, thus may override the ``color`` + encoding if conflicting encodings are specified. 2) See the scale documentation for + more information about customizing `color scheme + `__. + column + A field definition for the horizontal facet of trellis plots. + description + A text description of this mark for ARIA accessibility (SVG output only). For SVG + output the ``"aria-label"`` attribute will be set to this description. + detail + Additional levels of detail for grouping data in aggregate views and in line, trail, + and area marks without mapping data to a specific visual channel. + facet + A field definition for the (flexible) facet of trellis plots. + + If either ``row`` or ``column`` is specified, this channel will be ignored. + fill + Fill color of the marks. **Default value:** If undefined, the default color depends + on `mark config `__'s + ``color`` property. + + *Note:* The ``fill`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + fillOpacity + Fill opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__'s ``fillOpacity`` + property. + href + A URL to load upon mouse click. + key + A data field to use as a unique key for data binding. When a visualization's data is + updated, the key value will be used to match data elements to existing mark + instances. Use a key channel to enable object constancy for transitions over dynamic + data. + latitude + Latitude position of geographically projected marks. + latitude2 + Latitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + longitude + Longitude position of geographically projected marks. + longitude2 + Longitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + opacity + Opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__'s ``opacity`` + property. + order + Order of the marks. + + * For stacked marks, this ``order`` channel encodes `stack order + `__. + * For line and trail marks, this ``order`` channel encodes order of data points in + the lines. This can be useful for creating `a connected scatterplot + `__. Setting + ``order`` to ``{"value": null}`` makes the line marks use the original order in + the data sources. + * Otherwise, this ``order`` channel encodes layer order of the marks. + + **Note**: In aggregate plots, ``order`` field should be ``aggregate``d to avoid + creating additional aggregation grouping. + radius + The outer radius in pixels of arc marks. + radius2 + The inner radius in pixels of arc marks. + row + A field definition for the vertical facet of trellis plots. + shape + Shape of the mark. + + 1. For ``point`` marks the supported values include: - plotting shapes: + ``"circle"``, ``"square"``, ``"cross"``, ``"diamond"``, ``"triangle-up"``, + ``"triangle-down"``, ``"triangle-right"``, or ``"triangle-left"``. - the line + symbol ``"stroke"`` - centered directional shapes ``"arrow"``, ``"wedge"``, or + ``"triangle"`` - a custom `SVG path string + `__ (For correct + sizing, custom shape paths should be defined within a square bounding box with + coordinates ranging from -1 to 1 along both the x and y dimensions.) + + 2. For ``geoshape`` marks it should be a field definition of the geojson data + + **Default value:** If undefined, the default shape depends on `mark config + `__'s ``shape`` + property. (``"circle"`` if unset.) + size + Size of the mark. + + * For ``"point"``, ``"square"`` and ``"circle"``, - the symbol size, or pixel area + of the mark. + * For ``"bar"`` and ``"tick"`` - the bar and tick's size. + * For ``"text"`` - the text's font size. + * Size is unsupported for ``"line"``, ``"area"``, and ``"rect"``. (Use ``"trail"`` + instead of line with varying size) + stroke + Stroke color of the marks. **Default value:** If undefined, the default color + depends on `mark config + `__'s ``color`` + property. + + *Note:* The ``stroke`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + strokeDash + Stroke dash of the marks. + + **Default value:** ``[1,0]`` (No dash). + strokeOpacity + Stroke opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__'s + ``strokeOpacity`` property. + strokeWidth + Stroke width of the marks. + + **Default value:** If undefined, the default stroke width depends on `mark config + `__'s ``strokeWidth`` + property. + text + Text of the ``text`` mark. + theta + * For arc marks, the arc length in radians if theta2 is not specified, otherwise the + start arc angle. (A value of 0 indicates up or “north”, increasing values proceed + clockwise.) + + * For text marks, polar coordinate angle in radians. + theta2 + The end angle of arc marks in radians. A value of 0 indicates up or “north”, + increasing values proceed clockwise. + tooltip + The tooltip text to show upon mouse hover. Specifying ``tooltip`` encoding overrides + `the tooltip property in the mark definition + `__. + + See the `tooltip `__ + documentation for a detailed discussion about tooltip in Vega-Lite. + url + The URL of an image mark. + x + X coordinates of the marks, or width of horizontal ``"bar"`` and ``"area"`` without + specified ``x2`` or ``width``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + x2 + X2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + xError + Error value of x coordinates for error specified ``"errorbar"`` and ``"errorband"``. + xError2 + Secondary error value of x coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + xOffset + Offset of x-position of the marks + y + Y coordinates of the marks, or height of vertical ``"bar"`` and ``"area"`` without + specified ``y2`` or ``height``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + y2 + Y2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + yError + Error value of y coordinates for error specified ``"errorbar"`` and ``"errorband"``. + yError2 + Secondary error value of y coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + yOffset + Offset of y-position of the marks + """ + + angle: str | Angle | Map | AngleDatum | AngleValue + color: str | Color | Map | ColorDatum | ColorValue + column: str | Column | Map + description: str | Description | Map | DescriptionValue + detail: OneOrSeq[str | Detail | Map] + facet: str | Facet | Map + fill: str | Fill | Map | FillDatum | FillValue + fillOpacity: str | FillOpacity | Map | FillOpacityDatum | FillOpacityValue + href: str | Href | Map | HrefValue + key: str | Key | Map + latitude: str | Latitude | Map | LatitudeDatum + latitude2: str | Latitude2 | Map | Latitude2Datum | Latitude2Value + longitude: str | Longitude | Map | LongitudeDatum + longitude2: str | Longitude2 | Map | Longitude2Datum | Longitude2Value + opacity: str | Opacity | Map | OpacityDatum | OpacityValue + order: OneOrSeq[str | Order | Map | OrderValue] + radius: str | Radius | Map | RadiusDatum | RadiusValue + radius2: str | Radius2 | Map | Radius2Datum | Radius2Value + row: str | Row | Map + shape: str | Shape | Map | ShapeDatum | ShapeValue + size: str | Size | Map | SizeDatum | SizeValue + stroke: str | Stroke | Map | StrokeDatum | StrokeValue + strokeDash: str | StrokeDash | Map | StrokeDashDatum | StrokeDashValue + strokeOpacity: str | StrokeOpacity | Map | StrokeOpacityDatum | StrokeOpacityValue + strokeWidth: str | StrokeWidth | Map | StrokeWidthDatum | StrokeWidthValue + text: str | Text | Map | TextDatum | TextValue + theta: str | Theta | Map | ThetaDatum | ThetaValue + theta2: str | Theta2 | Map | Theta2Datum | Theta2Value + tooltip: OneOrSeq[str | Tooltip | Map | TooltipValue] + url: str | Url | Map | UrlValue + x: str | X | Map | XDatum | XValue + x2: str | X2 | Map | X2Datum | X2Value + xError: str | XError | Map | XErrorValue + xError2: str | XError2 | Map | XError2Value + xOffset: str | XOffset | Map | XOffsetDatum | XOffsetValue + y: str | Y | Map | YDatum | YValue + y2: str | Y2 | Map | Y2Datum | Y2Value + yError: str | YError | Map | YErrorValue + yError2: str | YError2 | Map | YError2Value + yOffset: str | YOffset | Map | YOffsetDatum | YOffsetValue diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 1a1db541d..0e24cd59f 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -17,7 +17,7 @@ # ruff: noqa: F405 if TYPE_CHECKING: from altair import Parameter - from altair.utils.schemapi import Optional + from altair.typing import Optional from ._typing import * # noqa: F403 diff --git a/altair/vegalite/v5/schema/mixins.py b/altair/vegalite/v5/schema/mixins.py index bc081e7df..940164158 100644 --- a/altair/vegalite/v5/schema/mixins.py +++ b/altair/vegalite/v5/schema/mixins.py @@ -22,7 +22,7 @@ # ruff: noqa: F405 if TYPE_CHECKING: - from altair.utils.schemapi import Optional + from altair.typing import Optional from ._typing import * # noqa: F403 diff --git a/doc/user_guide/api.rst b/doc/user_guide/api.rst index f6dac987e..eaa9cb602 100644 --- a/doc/user_guide/api.rst +++ b/doc/user_guide/api.rst @@ -152,7 +152,6 @@ API Functions condition graticule hconcat - is_chart_type layer param repeat @@ -638,3 +637,57 @@ API Utility Classes When Then ChainedWhen + +Typing +------ +.. currentmodule:: altair.typing + +.. autosummary:: + :toctree: generated/typing/ + :nosignatures: + + ChannelAngle + ChannelColor + ChannelColumn + ChannelDescription + ChannelDetail + ChannelFacet + ChannelFill + ChannelFillOpacity + ChannelHref + ChannelKey + ChannelLatitude + ChannelLatitude2 + ChannelLongitude + ChannelLongitude2 + ChannelOpacity + ChannelOrder + ChannelRadius + ChannelRadius2 + ChannelRow + ChannelShape + ChannelSize + ChannelStroke + ChannelStrokeDash + ChannelStrokeOpacity + ChannelStrokeWidth + ChannelText + ChannelTheta + ChannelTheta2 + ChannelTooltip + ChannelUrl + ChannelX + ChannelX2 + ChannelXError + ChannelXError2 + ChannelXOffset + ChannelY + ChannelY2 + ChannelYError + ChannelYError2 + ChannelYOffset + ChartType + EncodeKwds + Optional + is_chart_type + diff --git a/pyproject.toml b/pyproject.toml index 4f6d8f1ae..973963c62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -376,7 +376,12 @@ split-on-trailing-comma = false [tool.ruff.lint.flake8-tidy-imports.banned-api] # https://docs.astral.sh/ruff/settings/#lint_flake8-tidy-imports_banned-api -"typing.Optional".msg = "Use `Union[T, None]` instead.\n`typing.Optional` is likely to be confused with `altair.Optional`, which have a similar but different semantic meaning.\nSee https://github.com/vega/altair/pull/3449" +"typing.Optional".msg = """ +Use `Union[T, None]` instead. +`typing.Optional` is likely to be confused with `altair.typing.Optional`, \ +which have a similar but different semantic meaning. +See https://github.com/vega/altair/pull/3449 +""" [tool.ruff.lint.per-file-ignores] # Only enforce type annotation rules on public api diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 5d3f73c95..ecfcbcace 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -22,7 +22,7 @@ from packaging.version import Version import altair as alt -from altair.utils.schemapi import Undefined +from altair.utils.schemapi import Optional, Undefined try: import vl_convert as vlc @@ -667,9 +667,9 @@ def test_when_multiple_fields(): alt.selection_point(fields=["Horsepower"]), ], ) -@pytest.mark.parametrize("empty", [alt.Undefined, True, False]) +@pytest.mark.parametrize("empty", [Undefined, True, False]) def test_when_condition_parity( - cars, channel: str, when, empty: alt.Optional[bool], then, otherwise + cars, channel: str, when, empty: Optional[bool], then, otherwise ): params = [when] if isinstance(when, alt.Parameter) else () kwds = {"x": "Cylinders:N", "y": "Origin:N"} diff --git a/tools/generate_api_docs.py b/tools/generate_api_docs.py index ed655ab69..d3771d6b7 100644 --- a/tools/generate_api_docs.py +++ b/tools/generate_api_docs.py @@ -72,6 +72,17 @@ :nosignatures: {api_classes} + +Typing +------ +.. currentmodule:: altair.typing + +.. autosummary:: + :toctree: generated/typing/ + :nosignatures: + + {typing_objects} + """ @@ -109,7 +120,8 @@ def api_functions() -> list[str]: altair_api_functions = [ obj_name for obj_name in iter_objects(alt.api, restrict_to_type=types.FunctionType) # type: ignore[attr-defined] - if obj_name not in {"cast", "overload", "NamedTuple", "TypedDict"} + if obj_name + not in {"cast", "overload", "NamedTuple", "TypedDict", "is_chart_type"} ] return sorted(altair_api_functions) @@ -119,6 +131,10 @@ def api_classes() -> list[str]: return ["expr", "When", "Then", "ChainedWhen"] +def type_hints() -> list[str]: + return [s for s in sorted(iter_objects(alt.typing)) if s != "annotations"] + + def lowlevel_wrappers() -> list[str]: objects = sorted(iter_objects(alt.schema.core, restrict_to_subclass=alt.SchemaBase)) # type: ignore[attr-defined] # The names of these two classes are also used for classes in alt.channels. Due to @@ -140,6 +156,7 @@ def write_api_file() -> None: encoding_wrappers=sep.join(encoding_wrappers()), lowlevel_wrappers=sep.join(lowlevel_wrappers()), api_classes=sep.join(api_classes()), + typing_objects=sep.join(type_hints()), ), encoding="utf-8", ) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index df854b78a..8b4e7dd2e 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -209,7 +209,7 @@ def configure_{prop}(self, *args, **kwargs) -> Self: ENCODE_METHOD: Final = ''' class _EncodingMixin: - def encode({encode_method_args}) -> Self: + def encode({method_args}) -> Self: """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) {docstring}""" # Compat prep for `infer_encoding_types` signature @@ -233,6 +233,14 @@ def encode({encode_method_args}) -> Self: return copy ''' +ENCODE_TYPED_DICT: Final = ''' +class EncodeKwds(TypedDict, total=False): + """Encoding channels map properties of the data to visual properties of the chart. + {docstring}""" + {channels} + +''' + # NOTE: Not yet reasonable to generalize `TypeAliasType`, `TypeVar` # Revisit if this starts to become more common TYPING_EXTRA: Final = ''' @@ -529,7 +537,7 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: "from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses # noqa: F401\n", _type_checking_only_imports( "from altair import Parameter", - "from altair.utils.schemapi import Optional", + "from altair.typing import Optional", "from ._typing import * # noqa: F403", ), "\n" f"__all__ = {all_}\n", @@ -562,18 +570,29 @@ def _type_checking_only_imports(*imports: str) -> str: class ChannelInfo: supports_arrays: bool deep_description: str - field_class_name: str | None = None + field_class_name: str datum_class_name: str | None = None value_class_name: str | None = None + @property + def is_field_only(self) -> bool: + return not (self.datum_class_name or self.value_class_name) + @property def all_names(self) -> Iterator[str]: - if self.field_class_name: - yield self.field_class_name - if self.datum_class_name: - yield self.datum_class_name - if self.value_class_name: - yield self.value_class_name + """All channels are expected to have a field class.""" + yield self.field_class_name + yield from self.non_field_names + + @property + def non_field_names(self) -> Iterator[str]: + if self.is_field_only: + yield from () + else: + if self.datum_class_name: + yield self.datum_class_name + if self.value_class_name: + yield self.value_class_name def generate_vegalite_channel_wrappers( @@ -595,50 +614,37 @@ def generate_vegalite_channel_wrappers( supports_arrays = any( schema_info.is_array() for schema_info in propschema.anyOf ) + classname: str = prop[0].upper() + prop[1:] channel_info = ChannelInfo( supports_arrays=supports_arrays, deep_description=propschema.deep_description, + field_class_name=classname, ) for encoding_spec, definition in def_dict.items(): - classname = prop[0].upper() + prop[1:] basename = definition.rsplit("/", maxsplit=1)[-1] basename = get_valid_identifier(basename) + gen: SchemaGenerator defschema = {"$ref": definition} - - Generator: ( - type[FieldSchemaGenerator] - | type[DatumSchemaGenerator] - | type[ValueSchemaGenerator] - ) + kwds = { + "basename": basename, + "schema": defschema, + "rootschema": schema, + "encodingname": prop, + "haspropsetters": True, + } if encoding_spec == "field": - Generator = FieldSchemaGenerator - nodefault = [] - channel_info.field_class_name = classname - + gen = FieldSchemaGenerator(classname, nodefault=[], **kwds) elif encoding_spec == "datum": - Generator = DatumSchemaGenerator - classname += "Datum" - nodefault = ["datum"] - channel_info.datum_class_name = classname - + temp_name = f"{classname}Datum" + channel_info.datum_class_name = temp_name + gen = DatumSchemaGenerator(temp_name, nodefault=["datum"], **kwds) elif encoding_spec == "value": - Generator = ValueSchemaGenerator - classname += "Value" - nodefault = ["value"] - channel_info.value_class_name = classname - - gen = Generator( - classname=classname, - basename=basename, - schema=defschema, - rootschema=schema, - encodingname=prop, - nodefault=nodefault, - haspropsetters=True, - altair_classes_prefix="core", - ) + temp_name = f"{classname}Value" + channel_info.value_class_name = temp_name + gen = ValueSchemaGenerator(temp_name, nodefault=["value"], **kwds) + class_defs.append(gen.schema_class()) channel_infos[prop] = channel_info @@ -656,12 +662,14 @@ def generate_vegalite_channel_wrappers( imports = imports or [ "from __future__ import annotations\n", - "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING", + "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING, TypedDict", + "from typing_extensions import TypeAlias", "from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe", "from altair.utils.schemapi import Undefined, with_property_setters", "from altair.utils import infer_encoding_types as _infer_encoding_types", "from altair.utils import parse_shorthand", "from . import core", + "from ._typing import * # noqa: F403", ] contents = [ HEADER, @@ -669,18 +677,14 @@ def generate_vegalite_channel_wrappers( *imports, _type_checking_only_imports( "from altair import Parameter, SchemaBase", - "from altair.utils.schemapi import Optional", - "from ._typing import * # noqa: F403", + "from altair.typing import Optional", "from typing_extensions import Self", ), "\n" f"__all__ = {sorted(all_)}\n", CHANNEL_MIXINS, *class_defs, + *generate_encoding_artifacts(channel_infos, ENCODE_METHOD, ENCODE_TYPED_DICT), ] - - # Generate the type signature for the encode method - encode_signature = _create_encode_signature(channel_infos) - contents.append(encode_signature) return "\n".join(contents) @@ -832,7 +836,7 @@ def vegalite_main(skip_download: bool = False) -> None: "\n\n", _type_checking_only_imports( "from altair import Parameter, SchemaBase", - "from altair.utils.schemapi import Optional", + "from altair.typing import Optional", "from ._typing import * # noqa: F403", ), "\n\n\n", @@ -861,59 +865,68 @@ def vegalite_main(skip_download: bool = False) -> None: ruff_write_lint_format_str(fp, contents) -def _create_encode_signature( - channel_infos: dict[str, ChannelInfo], -) -> str: +def generate_encoding_artifacts( + channel_infos: dict[str, ChannelInfo], fmt_method: str, fmt_typed_dict: str +) -> Iterator[str]: + """ + Generate ``Chart.encode()`` and related typing structures. + + - `TypeAlias`(s) for each parameter to ``Chart.encode()`` + - Mixin class that provides the ``Chart.encode()`` method + - `TypedDict`, utilising/describing these structures as part of https://github.com/pola-rs/polars/pull/17995. + + Notes + ----- + - `Map`/`Dict` stands for the return types of `alt.(datum|value)`, and any encoding channel class. + - See discussions in https://github.com/vega/altair/pull/3208 + - We could be more specific about what types are accepted in the `List` + - but this translates poorly to an IDE + - `info.supports_arrays` + """ signature_args: list[str] = ["self", "*args: Any"] - docstring_parameters: list[str] = ["", "Parameters", "----------"] + type_aliases: list[str] = [] + typed_dict_args: list[str] = [] + signature_doc_params: list[str] = ["", "Parameters", "----------"] + typed_dict_doc_params: list[str] = ["", "Parameters", "----------"] + for channel, info in channel_infos.items(): - field_class_name = info.field_class_name - assert ( - field_class_name is not None - ), "All channels are expected to have a field class" - datum_and_value_class_names = [] - if info.datum_class_name is not None: - datum_and_value_class_names.append(info.datum_class_name) - - if info.value_class_name is not None: - datum_and_value_class_names.append(info.value_class_name) - - # dict stands for the return types of alt.datum, alt.value as well as - # the dictionary representation of an encoding channel class. See - # discussions in https://github.com/vega/altair/pull/3208 - # for more background. - union_types = ["str", field_class_name, "Map"] - docstring_union_types = ["str", rst_syntax_for_class(field_class_name), "Dict"] + alias_name: str = f"Channel{channel[0].upper()}{channel[1:]}" + + it: Iterator[str] = info.all_names + it_rst_names: Iterator[str] = (rst_syntax_for_class(c) for c in info.all_names) + + docstring_types: list[str] = ["str", next(it_rst_names), "Dict"] + tp_inner: str = ", ".join(chain(("str", next(it), "Map"), it)) + tp_inner = f"Union[{tp_inner}]" + if info.supports_arrays: - # We could be more specific about what types are accepted in the list - # but then the signatures would get rather long and less useful - # to a user when it shows up in their IDE. - union_types.append("list") - docstring_union_types.append("List") - - union_types = union_types + datum_and_value_class_names - docstring_union_types = docstring_union_types + [ - rst_syntax_for_class(c) for c in datum_and_value_class_names - ] + docstring_types.append("List") + tp_inner = f"OneOrSeq[{tp_inner}]" - signature_args.append( - f"{channel}: Optional[Union[{', '.join(union_types)}]] = Undefined" - ) + doc_types_flat: str = ", ".join(chain(docstring_types, it_rst_names)) - docstring_parameters.extend( - ( - f"{channel} : {', '.join(docstring_union_types)}", - f" {process_description(info.deep_description)}", - ) - ) - if len(docstring_parameters) > 1: - docstring_parameters += [""] - docstring = indent_docstring( - docstring_parameters, indent_level=8, width=100, lstrip=False + type_aliases.append(f"{alias_name}: TypeAlias = {tp_inner}") + # We use the full type hints instead of the alias in the signatures below + # as IDEs such as VS Code would else show the name of the alias instead + # of the expanded full type hints. The later are more useful to users. + typed_dict_args.append(f"{channel}: {tp_inner}") + signature_args.append(f"{channel}: Optional[{tp_inner}] = Undefined") + + description: str = f" {process_description(info.deep_description)}" + + signature_doc_params.extend((f"{channel} : {doc_types_flat}", description)) + typed_dict_doc_params.extend((f"{channel}", description)) + + method: str = fmt_method.format( + method_args=", ".join(signature_args), + docstring=indent_docstring(signature_doc_params, indent_level=8, lstrip=False), ) - return ENCODE_METHOD.format( - encode_method_args=", ".join(signature_args), docstring=docstring + typed_dict: str = fmt_typed_dict.format( + channels="\n ".join(typed_dict_args), + docstring=indent_docstring(typed_dict_doc_params, indent_level=4, lstrip=False), ) + artifacts: Iterable[str] = *type_aliases, method, typed_dict + yield from artifacts def main() -> None: diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index 3d232240d..0533964b4 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -105,7 +105,6 @@ class SchemaGenerator: rootschemarepr : CodeSnippet or object, optional An object whose repr will be used in the place of the explicit root schema. - altair_classes_prefix : string, optional **kwargs : dict Additional keywords for derived classes. """ @@ -141,7 +140,6 @@ def __init__( rootschemarepr: object | None = None, nodefault: list[str] | None = None, haspropsetters: bool = False, - altair_classes_prefix: str | None = None, **kwargs, ) -> None: self.classname = classname @@ -153,7 +151,6 @@ def __init__( self.nodefault = nodefault or () self.haspropsetters = haspropsetters self.kwargs = kwargs - self.altair_classes_prefix = altair_classes_prefix def subclasses(self) -> list[str]: """Return a list of subclass names, if any.""" @@ -226,16 +223,9 @@ def docstring(self, indent: int = 0) -> str: ): propinfo = info.properties[prop] doc += [ - "{} : {}".format( - prop, - propinfo.get_python_type_representation( - altair_classes_prefix=self.altair_classes_prefix, - ), - ), + f"{prop} : {propinfo.get_python_type_representation()}", f" {self._process_description(propinfo.deep_description)}", ] - if len(doc) > 1: - doc += [""] return indent_docstring(doc, indent_level=indent, width=100, lstrip=True) def init_code(self, indent: int = 0) -> str: @@ -279,9 +269,7 @@ def init_args( [ *additional_types, *info.properties[p].get_python_type_representation( - for_type_hints=True, - altair_classes_prefix=self.altair_classes_prefix, - return_as_str=False, + for_type_hints=True, return_as_str=False ), ] ) @@ -315,9 +303,7 @@ def get_args(self, si: SchemaInfo) -> list[str]: [ f"{p}: " + info.get_python_type_representation( - for_type_hints=True, - altair_classes_prefix=self.altair_classes_prefix, - additional_type_hints=["UndefinedType"], + for_type_hints=True, additional_type_hints=["UndefinedType"] ) + " = Undefined" for p, info in prop_infos.items() diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 0ef56ccf7..bd1827a89 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -42,7 +42,7 @@ from referencing import Registry - from altair import ChartType + from altair.typing import ChartType if sys.version_info >= (3, 13): from typing import TypeIs @@ -775,7 +775,7 @@ def __repr__(self) -> str: The parameters ``short``, ``long`` accept the same range of types:: # ruff: noqa: UP006, UP007 - from altair import Optional + from altair.typing import Optional def func_1( short: Optional[str | bool | float | dict[str, Any] | SchemaBase] = Undefined, @@ -784,10 +784,12 @@ def func_1( ] = Undefined, ): ... -This is distinct from `typing.Optional `__ as ``altair.Optional`` treats ``None`` like any other type:: +This is distinct from `typing.Optional `__. + +``altair.typing.Optional`` treats ``None`` like any other type:: # ruff: noqa: UP006, UP007 - from altair import Optional + from altair.typing import Optional def func_2( short: Optional[str | float | dict[str, Any] | None | SchemaBase] = Undefined, diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 9ff3953f0..17326a8a1 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -10,7 +10,16 @@ from html import unescape from itertools import chain from operator import itemgetter -from typing import TYPE_CHECKING, Any, Final, Iterable, Iterator, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Final, + Iterable, + Iterator, + Literal, + Sequence, + overload, +) import mistune from mistune.renderers.rst import RSTRenderer as _RSTRenderer @@ -363,21 +372,30 @@ def title(self) -> str: else: return "" + @overload + def get_python_type_representation( + self, + for_type_hints: bool = ..., + return_as_str: Literal[True] = ..., + additional_type_hints: list[str] | None = ..., + ) -> str: ... + @overload + def get_python_type_representation( + self, + for_type_hints: bool = ..., + return_as_str: Literal[False] = ..., + additional_type_hints: list[str] | None = ..., + ) -> list[str]: ... def get_python_type_representation( self, for_type_hints: bool = False, - altair_classes_prefix: str | None = None, return_as_str: bool = True, additional_type_hints: list[str] | None = None, ) -> str | list[str]: - # This is a list of all types which can be used for the current SchemaInfo. - # This includes Altair classes, standard Python types, etc. type_representations: list[str] = [] - TP_CHECK_ONLY = {"Parameter", "SchemaBase"} - """Most common annotations are include in `TYPE_CHECKING` block. - They do not require `core.` prefix, and this saves many lines of code. - - Eventually a more robust solution would apply to more types from `core`. + """ + All types which can be used for the current `SchemaInfo`. + Including `altair` classes, standard `python` types, etc. """ if self.title: @@ -385,7 +403,8 @@ def get_python_type_representation( # To keep type hints simple, we only use the SchemaBase class # as the type hint for all classes which inherit from it. class_names = ["SchemaBase"] - if self.title == "ExprRef": + if self.title in {"ExprRef", "ParameterExtent"}: + class_names.append("Parameter") # In these cases, a value parameter is also always accepted. # It would be quite complex to further differentiate # between a value and a selection parameter based on @@ -393,23 +412,7 @@ def get_python_type_representation( # try to check for the type of the Parameter.param attribute # but then we would need to write some overload signatures for # api.param). - class_names.append("Parameter") - if self.title == "ParameterExtent": - class_names.append("Parameter") - prefix = ( - "" if not altair_classes_prefix else altair_classes_prefix + "." - ) - # If there is no prefix, it might be that the class is defined - # in the same script and potentially after this line -> We use - # deferred type annotations using quotation marks. - if not prefix: - class_names = [f'"{n}"' for n in class_names] - else: - class_names = ( - n if n in TP_CHECK_ONLY else f"{prefix}{n}" for n in class_names - ) - # class_names = [f"{prefix}{n}" for n in class_names] type_representations.extend(class_names) else: # use RST syntax for generated sphinx docs @@ -427,26 +430,22 @@ def get_python_type_representation( tp_str = TypeAliasTracer.add_literal(self, spell_literal(it), replace=True) type_representations.append(tp_str) elif self.is_anyOf(): - type_representations.extend( - [ - s.get_python_type_representation( - for_type_hints=for_type_hints, - altair_classes_prefix=altair_classes_prefix, - return_as_str=False, - ) - for s in self.anyOf - ] + it = ( + s.get_python_type_representation( + for_type_hints=for_type_hints, return_as_str=False + ) + for s in self.anyOf ) + type_representations.extend(it) elif isinstance(self.type, list): options = [] subschema = SchemaInfo(dict(**self.schema)) for typ_ in self.type: subschema.schema["type"] = typ_ + # We always use title if possible for nested objects options.append( subschema.get_python_type_representation( - # We always use title if possible for nested objects - for_type_hints=for_type_hints, - altair_classes_prefix=altair_classes_prefix, + for_type_hints=for_type_hints ) ) type_representations.extend(options) @@ -469,14 +468,10 @@ def get_python_type_representation( # method. However, it is not entirely accurate as some sequences # such as e.g. a range are not supported by SchemaBase.to_dict but # this tradeoff seems worth it. - type_representations.append( - "Sequence[{}]".format( - self.child(self.items).get_python_type_representation( - for_type_hints=for_type_hints, - altair_classes_prefix=altair_classes_prefix, - ) - ) + s = self.child(self.items).get_python_type_representation( + for_type_hints=for_type_hints ) + type_representations.append(f"Sequence[{s}]") elif self.type in jsonschema_to_python_types: type_representations.append(jsonschema_to_python_types[self.type]) else: @@ -692,6 +687,8 @@ def indent_docstring( ) -> str: """Indent a docstring for use in generated code.""" final_lines = [] + if len(lines) > 1: + lines += [""] for i, line in enumerate(lines): stripped = line.lstrip()