Skip to content

Commit

Permalink
Fix identity scale + legend with breaks reordered
Browse files Browse the repository at this point in the history
fixes #735
  • Loading branch information
has2k1 committed Jan 4, 2024
1 parent de98e6b commit 3e1af2a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 16 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ title: Changelog

- Fixed using `facet_grid` with a column named `key`. ({{< issue 734 >}})

- Fixed using legend when using an identity scale and reordering the
breaks. ({{< issue 735 >}})

### Enhancements

- All `__all__` variables are explicitly assigned to help static typecheckers
Expand Down
2 changes: 1 addition & 1 deletion plotnine/guides/guide_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def train(self, scale, aesthetic=None):
scale name is one of the aesthetics: `x`, `y`, `color`,
`fill`, `size`, `shape`, `alpha`, `stroke`.
Returns this guide if trainning is successful and None
Returns this guide if training is successful and None
if it fails
"""
if aesthetic is None:
Expand Down
29 changes: 14 additions & 15 deletions plotnine/scales/scale_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,20 @@ def get_breaks(
The form is suitable for use by the guides e.g.
['fair', 'good', 'very good', 'premium', 'ideal']
"""
if limits is None:
limits = self.limits

if self.is_empty():
return []

if self.breaks is True:
breaks = list(limits)
elif self.breaks in (False, None):
if limits is None:
limits = self.limits

if self.breaks in (None, False):
breaks = []
elif self.breaks is True:
breaks = list(limits)
elif callable(self.breaks):
breaks = self.breaks(limits)
else:
_wanted_breaks = set(self.breaks)
breaks = [l for l in limits if l in _wanted_breaks]
breaks = list(self.breaks)

return breaks

Expand All @@ -264,10 +263,8 @@ def get_bounded_breaks(
if limits is None:
limits = self.limits

lookup = set(limits)
breaks = self.get_breaks()
strict_breaks = [b for b in breaks if b in lookup]
return strict_breaks
lookup_limits = set(limits)
return [b for b in self.get_breaks() if b in lookup_limits]

def get_labels(
self, breaks: Optional[ScaleDiscreteBreaks] = None
Expand All @@ -291,13 +288,15 @@ def get_labels(
return self.labels(breaks)
# if a dict is used to rename some labels
elif isinstance(self.labels, dict):
labels = [
return [
str(self.labels[b]) if b in self.labels else str(b)
for b in breaks
]
return labels
else:
return self.labels
# Return the labels in the order that they match with
# the breaks.
label_lookup = dict(zip(self.get_breaks(), self.labels))
return [label_lookup[b] for b in breaks]

def transform_df(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions tests/test_scale_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

import plotnine as p9
from plotnine import (
aes,
annotate,
Expand Down Expand Up @@ -710,6 +711,26 @@ def test_legend_ordering_added_scales():
assert p == "legend_ordering_added_scales"


def test_legend_ordering_with_identity_scale():
data = pd.DataFrame(
{
"x": [1, 2, 3, 4],
"y": [1, 2, 3, 4],
"color": ["blue", "blue", "red", "red"],
}
)

p = (
ggplot(data, aes("x", "y", color="color"))
+ geom_point()
+ p9.scale_color_identity(
breaks=["red", "blue"], labels=["Red", "Blue"], guide="legend"
)
)

assert p == "test_legend_ordering_with_identity_scale"


def test_breaks_and_labels_outside_of_limits():
data = pd.DataFrame({"x": range(5, 11), "y": range(5, 11)})
p = (
Expand Down

0 comments on commit 3e1af2a

Please sign in to comment.