Skip to content

Commit

Permalink
refactor(typing): Reuse generated Literal aliases in api (vega#3464)
Browse files Browse the repository at this point in the history
Following vega#3431 a number of these are now importable.
Additionally, I spotted the `encodings` parameter was annotated with `str`, but should be restricted to the constraints of `SingleDefUnitChannel_T`.
  • Loading branch information
dangotbanned authored Jul 9, 2024
1 parent eeecb66 commit e994fbd
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@
InlineDataset,
)
from altair.expr.core import Expression, GetAttrExpression
from .schema._typing import (
ImputeMethod_T,
SelectionType_T,
SelectionResolution_T,
SingleDefUnitChannel_T,
StackOffset_T,
)

ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]]

Expand Down Expand Up @@ -505,9 +512,7 @@ def param(
return parameter


def _selection(
type: Optional[Literal["interval", "point"]] = Undefined, **kwds
) -> Parameter:
def _selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter:
# We separate out the parameter keywords from the selection keywords

select_kwds = {"name", "bind", "value", "empty", "init", "views"}
Expand Down Expand Up @@ -537,9 +542,7 @@ def _selection(
message="""'selection' is deprecated.
Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings."""
)
def selection(
type: Optional[Literal["interval", "point"]] = Undefined, **kwds
) -> Parameter:
def selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter:
"""
Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create.
Expand Down Expand Up @@ -568,10 +571,10 @@ def selection_interval(
bind: Optional[Binding | str] = Undefined,
empty: Optional[bool] = Undefined,
expr: Optional[str | Expr | Expression] = Undefined,
encodings: Optional[list[str]] = Undefined,
encodings: Optional[list[SingleDefUnitChannel_T]] = Undefined,
on: Optional[str] = Undefined,
clear: Optional[str | bool] = Undefined,
resolve: Optional[Literal["global", "union", "intersect"]] = Undefined,
resolve: Optional[SelectionResolution_T] = Undefined,
mark: Optional[Mark] = Undefined,
translate: Optional[str | bool] = Undefined,
zoom: Optional[str | bool] = Undefined,
Expand Down Expand Up @@ -680,11 +683,11 @@ def selection_point(
bind: Optional[Binding | str] = Undefined,
empty: Optional[bool] = Undefined,
expr: Optional[Expr] = Undefined,
encodings: Optional[list[str]] = Undefined,
encodings: Optional[list[SingleDefUnitChannel_T]] = Undefined,
fields: Optional[list[str]] = Undefined,
on: Optional[str] = Undefined,
clear: Optional[str | bool] = Undefined,
resolve: Optional[Literal["global", "union", "intersect"]] = Undefined,
resolve: Optional[SelectionResolution_T] = Undefined,
toggle: Optional[str | bool] = Undefined,
nearest: Optional[bool] = Undefined,
**kwds,
Expand Down Expand Up @@ -1853,9 +1856,7 @@ def transform_impute(
frame: Optional[list[int | None]] = Undefined,
groupby: Optional[list[str | FieldName]] = Undefined,
keyvals: Optional[list[Any] | ImputeSequence] = Undefined,
method: Optional[
Literal["value", "mean", "median", "max", "min"] | ImputeMethod
] = Undefined,
method: Optional[ImputeMethod_T | ImputeMethod] = Undefined,
value=Undefined,
) -> Self:
"""
Expand Down Expand Up @@ -2378,7 +2379,7 @@ def transform_stack(
as_: str | FieldName | list[str],
stack: str | FieldName,
groupby: list[str | FieldName],
offset: Optional[Literal["zero", "center", "normalize"]] = Undefined,
offset: Optional[StackOffset_T] = Undefined,
sort: Optional[list[SortField]] = Undefined,
) -> Self:
"""
Expand Down Expand Up @@ -3061,7 +3062,7 @@ def interactive(
copy of self, with interactive axes added
"""
encodings = []
encodings: list[SingleDefUnitChannel_T] = []
if bind_x:
encodings.append("x")
if bind_y:
Expand Down Expand Up @@ -3351,7 +3352,7 @@ def interactive(
copy of self, with interactive axes added
"""
encodings = []
encodings: list[SingleDefUnitChannel_T] = []
if bind_x:
encodings.append("x")
if bind_y:
Expand Down Expand Up @@ -3448,7 +3449,7 @@ def interactive(
copy of self, with interactive axes added
"""
encodings = []
encodings: list[SingleDefUnitChannel_T] = []
if bind_x:
encodings.append("x")
if bind_y:
Expand Down Expand Up @@ -3547,7 +3548,7 @@ def interactive(
copy of self, with interactive axes added
"""
encodings = []
encodings: list[SingleDefUnitChannel_T] = []
if bind_x:
encodings.append("x")
if bind_y:
Expand Down

0 comments on commit e994fbd

Please sign in to comment.