diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 26415c5f8..b6ce8b66f 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -513,7 +513,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: _PredicateType: TypeAlias = Union[ Parameter, core.Expr, - Map, + "_ConditionExtra", _TestPredicateType, _expr_core.OperatorMixin, ] @@ -538,12 +538,6 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: ``` """ -_ConditionType: TypeAlias = t.Dict[str, Union[_TestPredicateType, Any]] -"""Intermediate type representing a converted `_PredicateType`. - -Prior to parsing any `_StatementType`. -""" - _LiteralValue: TypeAlias = Union[str, bool, float, int] """Primitive python value types.""" @@ -560,15 +554,15 @@ def _is_test_predicate(obj: Any) -> TypeIs[_TestPredicateType]: return isinstance(obj, (str, _expr_core.Expression, core.PredicateComposition)) -def _get_predicate_expr(p: Parameter) -> Optional[str | SchemaBase]: +def _get_predicate_expr(p: Parameter) -> Optional[_TestPredicateType]: # https://vega.github.io/vega-lite/docs/predicate.html return getattr(p.param, "expr", Undefined) def _predicate_to_condition( predicate: _PredicateType, *, empty: Optional[bool] = Undefined -) -> _ConditionType: - condition: _ConditionType +) -> _Condition: + condition: _Condition if isinstance(predicate, Parameter): predicate_expr = _get_predicate_expr(predicate) if predicate.param_type == "selection" or utils.is_undefined(predicate_expr): @@ -595,12 +589,12 @@ def _predicate_to_condition( def _condition_to_selection( - condition: _ConditionType, + condition: _Condition, if_true: _StatementType, if_false: _StatementType, **kwargs: Any, -) -> SchemaBase | dict[str, _ConditionType | Any]: - selection: SchemaBase | dict[str, _ConditionType | Any] +) -> SchemaBase | _Conditional[_Condition]: + selection: SchemaBase | _Conditional[_Condition] if isinstance(if_true, SchemaBase): if_true = if_true.to_dict() elif isinstance(if_true, str): @@ -614,17 +608,18 @@ def _condition_to_selection( else: if_true = utils.parse_shorthand(if_true) if_true.update(kwargs) - condition.update(if_true) + cond_mutable: Any = dict(condition) + cond_mutable.update(if_true) if isinstance(if_false, SchemaBase): # For the selection, the channel definitions all allow selections # already. So use this SchemaBase wrapper if possible. selection = if_false.copy() - selection.condition = condition + selection.condition = cond_mutable elif isinstance(if_false, (str, dict)): if isinstance(if_false, str): if_false = utils.parse_shorthand(if_false) if_false.update(kwargs) - selection = dict(condition=condition, **if_false) + selection = _Conditional(condition=cond_mutable, **if_false) # type: ignore[typeddict-item] else: raise TypeError(if_false) return selection @@ -785,7 +780,7 @@ def _parse_when( *more_predicates: _ComposablePredicateType, empty: Optional[bool], **constraints: _FieldEqualType, -) -> _ConditionType: +) -> _Condition: composed: _PredicateType if utils.is_undefined(predicate): if more_predicates or constraints: @@ -842,7 +837,7 @@ def _parse_otherwise( class _BaseWhen(Protocol): # NOTE: Temporary solution to non-SchemaBase copy - _condition: _ConditionType + _condition: _Condition def _when_then( self, statement: _StatementType, kwds: dict[str, Any], / @@ -866,7 +861,7 @@ class When(_BaseWhen): `polars.when `__ """ - def __init__(self, condition: _ConditionType, /) -> None: + def __init__(self, condition: _Condition, /) -> None: self._condition = condition def __repr__(self) -> str: @@ -1129,7 +1124,7 @@ class ChainedWhen(_BaseWhen): def __init__( self, - condition: _ConditionType, + condition: _Condition, conditions: _Conditional[_Conditions], /, ) -> None: @@ -1710,7 +1705,7 @@ def condition( *, empty: Optional[bool] = ..., **kwargs: Any, -) -> dict[str, _ConditionType | Any]: ... +) -> _Conditional[_Condition]: ... @overload def condition( predicate: _PredicateType, @@ -1719,7 +1714,7 @@ def condition( *, empty: Optional[bool] = ..., **kwargs: Any, -) -> dict[str, _ConditionType | Any]: ... +) -> _Conditional[_Condition]: ... @overload def condition( predicate: _PredicateType, if_true: str, if_false: str, **kwargs: Any @@ -1732,7 +1727,7 @@ def condition( *, empty: Optional[bool] = Undefined, **kwargs: Any, -) -> SchemaBase | dict[str, _ConditionType | Any]: +) -> SchemaBase | _Conditional[_Condition]: """ A conditional attribute or encoding.