Skip to content

Commit

Permalink
refactor(typing): Factor out _ConditionType, use _Condition (`Typ…
Browse files Browse the repository at this point in the history
…edDict`)

- Adds some more consistency between `condition` and `when-then-otherwise`
- 1 less thing to think about
  • Loading branch information
dangotbanned committed Sep 16, 2024
1 parent d20109e commit 0a6d599
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
_PredicateType: TypeAlias = Union[
Parameter,
core.Expr,
Map,
"_ConditionExtra",
_TestPredicateType,
_expr_core.OperatorMixin,
]
Expand All @@ -538,12 +538,6 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
```
"""

_ConditionType: TypeAlias = t.Dict[str, Union[_TestPredicateType, Any]]
"""Intermediate type representing a converted `_PredicateType`.
Prior to parsing any `_StatementType`.
"""

_LiteralValue: TypeAlias = Union[str, bool, float, int]
"""Primitive python value types."""

Expand All @@ -560,15 +554,15 @@ def _is_test_predicate(obj: Any) -> TypeIs[_TestPredicateType]:
return isinstance(obj, (str, _expr_core.Expression, core.PredicateComposition))


def _get_predicate_expr(p: Parameter) -> Optional[str | SchemaBase]:
def _get_predicate_expr(p: Parameter) -> Optional[_TestPredicateType]:
# https://vega.github.io/vega-lite/docs/predicate.html
return getattr(p.param, "expr", Undefined)


def _predicate_to_condition(
predicate: _PredicateType, *, empty: Optional[bool] = Undefined
) -> _ConditionType:
condition: _ConditionType
) -> _Condition:
condition: _Condition
if isinstance(predicate, Parameter):
predicate_expr = _get_predicate_expr(predicate)
if predicate.param_type == "selection" or utils.is_undefined(predicate_expr):
Expand All @@ -595,12 +589,12 @@ def _predicate_to_condition(


def _condition_to_selection(
condition: _ConditionType,
condition: _Condition,
if_true: _StatementType,
if_false: _StatementType,
**kwargs: Any,
) -> SchemaBase | dict[str, _ConditionType | Any]:
selection: SchemaBase | dict[str, _ConditionType | Any]
) -> SchemaBase | _Conditional[_Condition]:
selection: SchemaBase | _Conditional[_Condition]
if isinstance(if_true, SchemaBase):
if_true = if_true.to_dict()
elif isinstance(if_true, str):
Expand All @@ -614,17 +608,18 @@ def _condition_to_selection(
else:
if_true = utils.parse_shorthand(if_true)
if_true.update(kwargs)
condition.update(if_true)
cond_mutable: Any = dict(condition)
cond_mutable.update(if_true)
if isinstance(if_false, SchemaBase):
# For the selection, the channel definitions all allow selections
# already. So use this SchemaBase wrapper if possible.
selection = if_false.copy()
selection.condition = condition
selection.condition = cond_mutable
elif isinstance(if_false, (str, dict)):
if isinstance(if_false, str):
if_false = utils.parse_shorthand(if_false)
if_false.update(kwargs)
selection = dict(condition=condition, **if_false)
selection = _Conditional(condition=cond_mutable, **if_false) # type: ignore[typeddict-item]
else:
raise TypeError(if_false)
return selection
Expand Down Expand Up @@ -785,7 +780,7 @@ def _parse_when(
*more_predicates: _ComposablePredicateType,
empty: Optional[bool],
**constraints: _FieldEqualType,
) -> _ConditionType:
) -> _Condition:
composed: _PredicateType
if utils.is_undefined(predicate):
if more_predicates or constraints:
Expand Down Expand Up @@ -842,7 +837,7 @@ def _parse_otherwise(

class _BaseWhen(Protocol):
# NOTE: Temporary solution to non-SchemaBase copy
_condition: _ConditionType
_condition: _Condition

def _when_then(
self, statement: _StatementType, kwds: dict[str, Any], /
Expand All @@ -866,7 +861,7 @@ class When(_BaseWhen):
`polars.when <https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.when.html>`__
"""

def __init__(self, condition: _ConditionType, /) -> None:
def __init__(self, condition: _Condition, /) -> None:
self._condition = condition

def __repr__(self) -> str:
Expand Down Expand Up @@ -1129,7 +1124,7 @@ class ChainedWhen(_BaseWhen):

def __init__(
self,
condition: _ConditionType,
condition: _Condition,
conditions: _Conditional[_Conditions],
/,
) -> None:
Expand Down Expand Up @@ -1710,7 +1705,7 @@ def condition(
*,
empty: Optional[bool] = ...,
**kwargs: Any,
) -> dict[str, _ConditionType | Any]: ...
) -> _Conditional[_Condition]: ...
@overload
def condition(
predicate: _PredicateType,
Expand All @@ -1719,7 +1714,7 @@ def condition(
*,
empty: Optional[bool] = ...,
**kwargs: Any,
) -> dict[str, _ConditionType | Any]: ...
) -> _Conditional[_Condition]: ...
@overload
def condition(
predicate: _PredicateType, if_true: str, if_false: str, **kwargs: Any
Expand All @@ -1732,7 +1727,7 @@ def condition(
*,
empty: Optional[bool] = Undefined,
**kwargs: Any,
) -> SchemaBase | dict[str, _ConditionType | Any]:
) -> SchemaBase | _Conditional[_Condition]:
"""
A conditional attribute or encoding.
Expand Down

0 comments on commit 0a6d599

Please sign in to comment.