Skip to content

Commit

Permalink
Make element_text accept a sequence of values
Browse files Browse the repository at this point in the history
closes #724
  • Loading branch information
has2k1 committed Jan 8, 2024
1 parent 3d2edb7 commit ab46610
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 36 deletions.
4 changes: 4 additions & 0 deletions doc/changelog.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ title: Changelog

to control the x & y tick padding.

- Some parameters in [](:class:`~plotnine.themes.element_text`) can now
accept lists/tuples to set the values on individual text objects.
({{< issue 724 >}})

### Bug Fixes

- Fixed handling of minor breaks in
Expand Down
24 changes: 17 additions & 7 deletions plotnine/themes/elements/element_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .margin import Margin

if TYPE_CHECKING:
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Sequence

from plotnine.typing import Theme, TupleFloat3, TupleFloat4

Expand Down Expand Up @@ -63,15 +63,25 @@ class element_text(element_base):
def __init__(
self,
family: Optional[str | list[str]] = None,
style: Optional[str] = None,
weight: Optional[int | str] = None,
color: Optional[str | TupleFloat3 | TupleFloat4] = None,
size: Optional[float] = None,
style: Optional[str | Sequence[str]] = None,
weight: Optional[int | str | Sequence[int | str]] = None,
color: Optional[
str
| TupleFloat3
| TupleFloat4
| Sequence[str | TupleFloat3 | TupleFloat4]
] = None,
size: Optional[float | Sequence[float]] = None,
ha: Optional[Literal["center", "left", "right"]] = None,
va: Optional[Literal["center", "top", "bottom", "baseline"]] = None,
rotation: Optional[float] = None,
rotation: Optional[Literal["vertical", "horizontal"] | float] = None,
linespacing: Optional[float] = None,
backgroundcolor: Optional[str | TupleFloat3 | TupleFloat4] = None,
backgroundcolor: Optional[
str
| TupleFloat3
| TupleFloat4
| Sequence[str | TupleFloat3 | TupleFloat4]
] = None,
margin: Optional[
dict[Literal["t", "b", "l", "r", "units"], Any]
] = None,
Expand Down
82 changes: 53 additions & 29 deletions plotnine/themes/themeable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np

from .._utils import to_rgba
from .._utils.registry import RegistryHierarchyMeta
from ..exceptions import PlotnineError
Expand All @@ -23,9 +25,10 @@

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any, Type
from typing import Any, Sequence, Type

from matplotlib.patches import Patch
from matplotlib.text import Text

from plotnine.typing import Axes, Figure, Theme

Expand Down Expand Up @@ -372,6 +375,46 @@ def _blankout_rect(rect: Patch):
rect.set_linewidth(0)


class texts_themeable(themeable):
"""
Base class for themeables that modify lists of Text
Where possible, setting the properties with a sequence
of values.
e.g.
theme(axis_text_x=element_text(color=("red", "green", "blue")))
The number of values in the sequence must match the number of
text objects in the figure.
"""

def set(self, texts: Sequence[Text]):
properties = self.properties.copy()
with suppress(KeyError):
del properties["margin"]

n = len(texts)
seq_properties = {}
for name, value in properties.items():
if (
isinstance(value, (list, tuple, np.ndarray))
and len(value) == n
):
seq_properties[name] = value

for key in seq_properties:
del properties[key]

for t in texts:
t.set(**properties)

for name, values in seq_properties.items():
for t, value in zip(texts, values):
t.set(**{name: value})


# element_text themeables


Expand Down Expand Up @@ -624,7 +667,7 @@ def blank_figure(self, figure: Figure, targets: dict[str, Any]):
text.set_visible(False)


class strip_text_x(themeable):
class strip_text_x(texts_themeable):
"""
Facet labels along the horizontal axis
Expand All @@ -635,13 +678,8 @@ class strip_text_x(themeable):

def apply_figure(self, figure: Figure, targets: dict[str, Any]):
super().apply_figure(figure, targets)
properties = self.properties.copy()
with suppress(KeyError):
del properties["margin"]
with suppress(KeyError):
texts = targets["strip_text_x"]
for text in texts:
text.set(**properties)
self.set(targets["strip_text_x"])

with suppress(KeyError):
rects = targets["strip_background_x"]
Expand All @@ -661,7 +699,7 @@ def blank_figure(self, figure: Figure, targets: dict[str, Any]):
rect.set_visible(False)


class strip_text_y(themeable):
class strip_text_y(texts_themeable):
"""
Facet labels along the vertical axis
Expand All @@ -672,13 +710,8 @@ class strip_text_y(themeable):

def apply_figure(self, figure: Figure, targets: dict[str, Any]):
super().apply_figure(figure, targets)
properties = self.properties.copy()
with suppress(KeyError):
del properties["margin"]
with suppress(KeyError):
texts = targets["strip_text_y"]
for text in texts:
text.set(**properties)
self.set(targets["strip_text_y"])

with suppress(KeyError):
rects = targets["strip_background_y"]
Expand Down Expand Up @@ -718,7 +751,7 @@ class title(axis_title, legend_title, plot_title, plot_subtitle, plot_caption):
"""


class axis_text_x(themeable):
class axis_text_x(texts_themeable):
"""
x-axis tick labels
Expand All @@ -729,12 +762,7 @@ class axis_text_x(themeable):

def apply_ax(self, ax: Axes):
super().apply_ax(ax)
properties = self.properties.copy()
with suppress(KeyError):
del properties["margin"]
labels = ax.get_xticklabels()
for l in labels:
l.set(**properties)
self.set(ax.get_xticklabels())

def blank_ax(self, ax: Axes):
super().blank_ax(ax)
Expand All @@ -743,7 +771,7 @@ def blank_ax(self, ax: Axes):
)


class axis_text_y(themeable):
class axis_text_y(texts_themeable):
"""
y-axis tick labels
Expand All @@ -754,12 +782,7 @@ class axis_text_y(themeable):

def apply_ax(self, ax: Axes):
super().apply_ax(ax)
properties = self.properties.copy()
with suppress(KeyError):
del properties["margin"]
labels = ax.get_yticklabels()
for l in labels:
l.set(**properties)
self.set(ax.get_yticklabels())

def blank_ax(self, ax: Axes):
super().blank_ax(ax)
Expand Down Expand Up @@ -903,6 +926,7 @@ def apply_ax(self, ax: Axes):
# We split the properties so that set_tick_params keeps
# record of the properties it cares about so that it does
# not undo them. GH703
# https://github.com/matplotlib/matplotlib/issues/26008
tick_params = {}
properties = self.properties.copy()
with suppress(KeyError):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions tests/test_theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ def test_no_ticks():
assert p == "no_ticks"


def test_element_text_with_sequence_values():
p = (
ggplot(mtcars, aes("wt", "mpg"))
+ geom_point()
+ facet_grid(("am", "cyl"))
+ theme(
axis_text=element_text(color="gray"),
axis_text_x=element_text(
color=("red", "green", "blue", "purple"), size=(8, 12, 16, 20)
),
strip_text_x=element_text(color=("black", "brown", "cyan")),
strip_text_y=element_text(color=("teal", "orange")),
)
)
assert p == "element_text_with_sequence_values"


class TestThemes:
g = (
ggplot(mtcars, aes(x="wt", y="mpg", color="factor(gear)"))
Expand Down

0 comments on commit ab46610

Please sign in to comment.