diff --git a/doc/changelog.qmd b/doc/changelog.qmd index 4da468bc6..f44b4bc9d 100644 --- a/doc/changelog.qmd +++ b/doc/changelog.qmd @@ -63,6 +63,9 @@ title: Changelog - Fixed using `facet_grid` with a column named `key`. ({{< issue 734 >}}) +- Fixed using legend when using an identity scale and reordering the + breaks. ({{< issue 735 >}}) + ### Enhancements - All `__all__` variables are explicitly assigned to help static typecheckers diff --git a/plotnine/guides/guide_legend.py b/plotnine/guides/guide_legend.py index 086bea1fd..9e5ec80c3 100644 --- a/plotnine/guides/guide_legend.py +++ b/plotnine/guides/guide_legend.py @@ -69,7 +69,7 @@ def train(self, scale, aesthetic=None): scale name is one of the aesthetics: `x`, `y`, `color`, `fill`, `size`, `shape`, `alpha`, `stroke`. - Returns this guide if trainning is successful and None + Returns this guide if training is successful and None if it fails """ if aesthetic is None: diff --git a/plotnine/scales/scale_discrete.py b/plotnine/scales/scale_discrete.py index a334444b5..123ececb3 100644 --- a/plotnine/scales/scale_discrete.py +++ b/plotnine/scales/scale_discrete.py @@ -237,21 +237,20 @@ def get_breaks( The form is suitable for use by the guides e.g. ['fair', 'good', 'very good', 'premium', 'ideal'] """ - if limits is None: - limits = self.limits - if self.is_empty(): return [] - if self.breaks is True: - breaks = list(limits) - elif self.breaks in (False, None): + if limits is None: + limits = self.limits + + if self.breaks in (None, False): breaks = [] + elif self.breaks is True: + breaks = list(limits) elif callable(self.breaks): breaks = self.breaks(limits) else: - _wanted_breaks = set(self.breaks) - breaks = [l for l in limits if l in _wanted_breaks] + breaks = list(self.breaks) return breaks @@ -264,10 +263,8 @@ def get_bounded_breaks( if limits is None: limits = self.limits - lookup = set(limits) - breaks = self.get_breaks() - strict_breaks = [b for b in breaks if b in lookup] - return strict_breaks + lookup_limits = set(limits) + return [b for b in self.get_breaks() if b in lookup_limits] def get_labels( self, breaks: Optional[ScaleDiscreteBreaks] = None @@ -291,13 +288,15 @@ def get_labels( return self.labels(breaks) # if a dict is used to rename some labels elif isinstance(self.labels, dict): - labels = [ + return [ str(self.labels[b]) if b in self.labels else str(b) for b in breaks ] - return labels else: - return self.labels + # Return the labels in the order that they match with + # the breaks. + label_lookup = dict(zip(self.get_breaks(), self.labels)) + return [label_lookup[b] for b in breaks] def transform_df(self, df: pd.DataFrame) -> pd.DataFrame: """ diff --git a/tests/baseline_images/test_scale_internals/test_legend_ordering_with_identity_scale.png b/tests/baseline_images/test_scale_internals/test_legend_ordering_with_identity_scale.png new file mode 100644 index 000000000..421709c82 Binary files /dev/null and b/tests/baseline_images/test_scale_internals/test_legend_ordering_with_identity_scale.png differ diff --git a/tests/test_scale_internals.py b/tests/test_scale_internals.py index 4a32fe818..05d1ea7c2 100644 --- a/tests/test_scale_internals.py +++ b/tests/test_scale_internals.py @@ -6,6 +6,7 @@ import pandas as pd import pytest +import plotnine as p9 from plotnine import ( aes, annotate, @@ -710,6 +711,26 @@ def test_legend_ordering_added_scales(): assert p == "legend_ordering_added_scales" +def test_legend_ordering_with_identity_scale(): + data = pd.DataFrame( + { + "x": [1, 2, 3, 4], + "y": [1, 2, 3, 4], + "color": ["blue", "blue", "red", "red"], + } + ) + + p = ( + ggplot(data, aes("x", "y", color="color")) + + geom_point() + + p9.scale_color_identity( + breaks=["red", "blue"], labels=["Red", "Blue"], guide="legend" + ) + ) + + assert p == "test_legend_ordering_with_identity_scale" + + def test_breaks_and_labels_outside_of_limits(): data = pd.DataFrame({"x": range(5, 11), "y": range(5, 11)}) p = (