Skip to content

Commit

Permalink
TYP: Use fewer typealiases for scale attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Jul 31, 2024
1 parent a277bd2 commit bb143cc
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 99 deletions.
6 changes: 2 additions & 4 deletions plotnine/iapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
from plotnine.typing import (
CoordRange,
FloatArrayLike,
ScaleBreaks,
ScaledAestheticsName,
ScaleLimits,
StripPosition,
)

Expand All @@ -41,10 +39,10 @@ class scale_view:
aesthetics: list[ScaledAestheticsName]
name: Optional[str]
# Trained limits of the scale
limits: ScaleLimits
limits: tuple[float, float] | Sequence[str]
# Physical size of scale, including expansions
range: CoordRange
breaks: ScaleBreaks
breaks: Sequence[float] | Sequence[str]
minor_breaks: FloatArrayLike
labels: Sequence[str]

Expand Down
22 changes: 5 additions & 17 deletions plotnine/scales/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
from .range import Range

if TYPE_CHECKING:
from typing import Any
from typing import Any, Sequence

import pandas as pd
from numpy.typing import NDArray

from plotnine.typing import (
ScaleBreaks,
ScaledAestheticsName,
ScaleLabels,
ScaleLabelsUser,
ScaleLimits,
)

from ..iapi import range_view, scale_view
Expand Down Expand Up @@ -288,7 +285,7 @@ def is_empty(self) -> bool:
return self._range.is_empty() and self.limits is None

@property
def final_limits(self) -> ScaleLimits:
def final_limits(self) -> Any:
raise NotImplementedError

def train_df(self, df: pd.DataFrame):
Expand All @@ -312,28 +309,19 @@ def map_df(self, df: pd.DataFrame) -> pd.DataFrame:

return df

def get_labels(
self,
breaks=None, # : Optional[ScaleBreaks]
) -> ScaleLabels:
def get_labels(self, breaks=None) -> Sequence[str]:
"""
Get labels, calculating them if required
"""
raise NotImplementedError

def get_breaks(
self,
limits=None, # : Optional[ScaleLimits]
) -> ScaleBreaks:
def get_breaks(self, limits=None):
"""
Get Breaks
"""
raise NotImplementedError

def get_bounded_breaks(
self,
limits=None, # : Optional[ScaleLimits]
) -> ScaleBreaks:
def get_bounded_breaks(self, limits=None):
"""
Return Breaks that are within the limits
"""
Expand Down
53 changes: 23 additions & 30 deletions plotnine/scales/scale_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Sequence,
Type,
Expand All @@ -18,7 +17,6 @@
from mizani.bounds import censor, expand_range_distinct, rescale, zero_range
from mizani.palettes import identity_pal
from mizani.transforms import trans
from numpy.typing import NDArray # noqa: TCH002

from .._utils import match
from ..exceptions import PlotnineError, PlotnineWarning
Expand All @@ -34,27 +32,22 @@

from plotnine.typing import (
CoordRange,
ScaleLabels,
ScaleMinorBreaksUser,
FloatArrayLike,
TFloatArrayLike,
)

GuideTypeT = TypeVar("GuideTypeT")
AnyArrayLike: TypeAlias = "NDArray[Any] | pd.Series[Any] | Sequence[Any]"
FloatArrayLike: TypeAlias = (
"NDArray[np.float64] | pd.Series[float] | Sequence[float]" # noqa: E501
)
ContinuousPalette: TypeAlias = "Callable[[FloatArrayLike], AnyArrayLike]"
ContinuousBreaks: TypeAlias = Sequence[float]
ContinuousLimits: TypeAlias = tuple[float, float]
ContinuousBreaksUser: TypeAlias = (
bool
| None
| ContinuousBreaks
| Callable[[ContinuousLimits], ContinuousBreaks]
| Sequence[float]
| Callable[[tuple[float, float]], Sequence[float]]
)
MinorBreaksUser: TypeAlias = ContinuousBreaksUser
ContinuousLimitsUser: TypeAlias = (
None | ContinuousLimits | Callable[[ContinuousLimits], ContinuousLimits]
None
| tuple[float, float]
| Callable[[tuple[float, float]], tuple[float, float]]
)
TransUser: TypeAlias = trans | str | Type[trans] | None

Expand Down Expand Up @@ -105,7 +98,7 @@ class scale_continuous(
Major breaks
"""

minor_breaks: ScaleMinorBreaksUser = True
minor_breaks: MinorBreaksUser = True
"""
If a list-like, it is the minor breaks points. If an integer, it is the
number of minor breaks between any set of major breaks.
Expand Down Expand Up @@ -192,7 +185,7 @@ def _make_trans(self) -> trans:
return t

@property
def final_limits(self) -> ContinuousLimits:
def final_limits(self) -> tuple[float, float]:
if self.is_empty():
return (0, 1)

Expand Down Expand Up @@ -310,7 +303,7 @@ def dimension(self, expand=(0, 0, 0, 0), limits=None):

def expand_limits(
self,
limits: ContinuousLimits,
limits: tuple[float, float],
expand: tuple[float, float] | tuple[float, float, float, float],
coord_limits: CoordRange | None,
trans: trans,
Expand Down Expand Up @@ -374,7 +367,7 @@ def palette(self, x):
return identity_pal()(x)

def map(
self, x: FloatArrayLike, limits: Optional[ContinuousLimits] = None
self, x: FloatArrayLike, limits: Optional[tuple[float, float]] = None
) -> FloatArrayLike:
if limits is None:
limits = self.final_limits
Expand All @@ -391,8 +384,8 @@ def map(
return scaled

def get_breaks(
self, limits: Optional[ContinuousLimits] = None
) -> ContinuousBreaks:
self, limits: Optional[tuple[float, float]] = None
) -> Sequence[float]:
"""
Generate breaks for the axis or legend
Expand Down Expand Up @@ -425,7 +418,7 @@ def get_breaks(
# TODO: Fix this type mismatch in mizani with
# a typevar so that type-in = type-out
_tlimits = self._trans.breaks(_limits)
breaks: ContinuousBreaks = _tlimits # pyright: ignore
breaks: Sequence[float] = _tlimits # pyright: ignore
elif zero_range(_limits):
breaks = [_limits[0]]
elif callable(self.breaks):
Expand All @@ -437,8 +430,8 @@ def get_breaks(
return breaks

def get_bounded_breaks(
self, limits: Optional[ContinuousLimits] = None
) -> ContinuousBreaks:
self, limits: Optional[tuple[float, float]] = None
) -> Sequence[float]:
"""
Return Breaks that are within limits
"""
Expand All @@ -450,9 +443,9 @@ def get_bounded_breaks(

def get_minor_breaks(
self,
major: ContinuousBreaks,
limits: Optional[ContinuousLimits] = None,
) -> ContinuousBreaks:
major: Sequence[float],
limits: Optional[tuple[float, float]] = None,
) -> Sequence[float]:
"""
Return minor breaks
"""
Expand All @@ -462,11 +455,11 @@ def get_minor_breaks(
if self.minor_breaks is False or self.minor_breaks is None:
minor_breaks = []
elif self.minor_breaks is True:
minor_breaks: ContinuousBreaks = self._trans.minor_breaks(
minor_breaks: Sequence[float] = self._trans.minor_breaks(
major, limits
) # pyright: ignore
elif isinstance(self.minor_breaks, int):
minor_breaks: ContinuousBreaks = self._trans.minor_breaks(
minor_breaks: Sequence[float] = self._trans.minor_breaks(
major,
limits,
self.minor_breaks, # pyright: ignore
Expand All @@ -482,8 +475,8 @@ def get_minor_breaks(
return minor_breaks

def get_labels(
self, breaks: Optional[ContinuousBreaks] = None
) -> ScaleLabels:
self, breaks: Optional[Sequence[float]] = None
) -> Sequence[str]:
"""
Generate labels for the axis or legend
Expand Down
35 changes: 13 additions & 22 deletions plotnine/scales/scale_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,16 @@

from mizani.transforms import trans

from plotnine.typing import (
CoordRange,
ScaleDiscreteBreaks,
ScaleDiscreteLimits,
ScaleLabels,
)
from plotnine.typing import CoordRange


AnyArrayLike: TypeAlias = "NDArray[Any] | pd.Series[Any] | Sequence[Any]"
DiscretePalette: TypeAlias = "Callable[[int], AnyArrayLike | dict[Any, Any]]"
DiscreteBreaks: TypeAlias = Sequence[str]
DiscreteLimits: TypeAlias = Sequence[str]
DiscreteBreaksUser: TypeAlias = (
bool | None | DiscreteBreaks | Callable[[DiscreteLimits], DiscreteBreaks]
bool | None | Sequence[str] | Callable[[Sequence[str]], Sequence[str]]
)
DiscreteLimitsUser: TypeAlias = (
None | DiscreteLimits | Callable[[DiscreteLimits], DiscreteLimits]
None | Sequence[str] | Callable[[Sequence[str]], Sequence[str]]
)


Expand Down Expand Up @@ -93,7 +86,7 @@ def __post_init__(self):
self._range = RangeDiscrete()

@property
def final_limits(self) -> ScaleDiscreteLimits:
def final_limits(self) -> Sequence[str]:
if self.is_empty():
return ("0", "1")

Expand Down Expand Up @@ -136,7 +129,7 @@ def dimension(self, expand=(0, 0, 0, 0), limits=None):

def expand_limits(
self,
limits: ScaleDiscreteLimits,
limits: Sequence[str],
expand: tuple[float, float] | tuple[float, float, float, float],
coord_limits: tuple[float, float],
trans: trans,
Expand All @@ -160,7 +153,7 @@ def expand_limits(

def view(
self,
limits: Optional[ScaleDiscreteLimits] = None,
limits: Optional[Sequence[str]] = None,
range: Optional[CoordRange] = None,
) -> scale_view:
"""
Expand Down Expand Up @@ -201,9 +194,7 @@ def palette(self, n: int) -> Sequence[Any]:
"""
return none_pal()(n)

def map(
self, x, limits: Optional[ScaleDiscreteLimits] = None
) -> Sequence[Any]:
def map(self, x, limits: Optional[Sequence[str]] = None) -> Sequence[Any]:
"""
Map values in x to a palette
"""
Expand Down Expand Up @@ -248,8 +239,8 @@ def map(
return pal_match

def get_breaks(
self, limits: Optional[ScaleDiscreteLimits] = None
) -> ScaleDiscreteBreaks:
self, limits: Optional[Sequence[str]] = None
) -> Sequence[str]:
"""
Return an ordered list of breaks
Expand All @@ -274,8 +265,8 @@ def get_breaks(
return breaks

def get_bounded_breaks(
self, limits: Optional[ScaleDiscreteLimits] = None
) -> ScaleDiscreteBreaks:
self, limits: Optional[Sequence[str]] = None
) -> Sequence[str]:
"""
Return Breaks that are within limits
"""
Expand All @@ -286,8 +277,8 @@ def get_bounded_breaks(
return [b for b in self.get_breaks() if b in lookup_limits]

def get_labels(
self, breaks: Optional[ScaleDiscreteBreaks] = None
) -> ScaleLabels:
self, breaks: Optional[Sequence[str]] = None
) -> Sequence[str]:
"""
Generate labels for the legend/guide breaks
"""
Expand Down
Loading

0 comments on commit bb143cc

Please sign in to comment.