Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implemented y-axis-based labeling #136

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ for a in A:
ax.plot(X, loglaplace(4).pdf(a * X), label=str(a))

xvals = [0.8, 0.55, 0.22, 0.104, 0.045]
labelLines(ax.get_lines(), align=False, xvals=xvals, color="k")
labelLines(ax.get_lines(), align=False, vals=xvals, color="k")

ax = axes[3]
for a in A:
Expand Down Expand Up @@ -77,7 +77,7 @@ ax = axes[5]
for a in A:
ax.semilogx(X, chi2(5).pdf(a * X), label=str(a))

labelLines(ax.get_lines(), xvals=(0.1, 1), zorder=2.5)
labelLines(ax.get_lines(), vals=(0.1, 1), zorder=2.5)

fig.show()
```
Expand Down
276 changes: 172 additions & 104 deletions example/matplotlib_label_lines.ipynb

Large diffs are not rendered by default.

127 changes: 75 additions & 52 deletions labellines/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# Label line with line2D label data
def labelLine(
line,
x,
val,
axis="x",
label=None,
align=True,
drop_label=False,
yoffset=0,
yoffset_logspace=False,
offset=0,
offset_logspace=False,
outline_color="auto",
outline_width=8,
**kwargs,
Expand All @@ -30,17 +31,19 @@ def labelLine(
----------
line : matplotlib.lines.Line
The line holding the label
x : number
val : number
The location in data unit of the label
axis : "x" | "y"
Reference axis for `val`.
Comment on lines +36 to +37
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am overall happy with the val/axis combination. Two notes though

  • we need to keep the old (yoffset, yoffset_logspace) for backward compatibility
  • add a note in the docstring below to explain how to use the axis kwa?

label : string, optional
The label to set. This is inferred from the line by default
drop_label : bool, optional
If True, the label is consumed by the function so that subsequent
calls to e.g. legend do not use it anymore.
yoffset : double, optional
offset : double, optional
Space to add to label's y position
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Space to add to label's y position
Space to add to label's x/y position

yoffset_logspace : bool, optional
If True, then yoffset will be added to the label's y position in
offset_logspace : bool, optional
If True, then offset will be added to the label's y position in
log10 space
outline_color : None | "auto" | color
Colour of the outline. If set to "auto", use the background color.
Expand All @@ -54,11 +57,12 @@ def labelLine(
try:
txt = LineLabel(
line,
x,
val,
axis,
label=label,
align=align,
yoffset=yoffset,
yoffset_logspace=yoffset_logspace,
offset=offset,
offset_logspace=offset_logspace,
outline_color=outline_color,
outline_width=outline_width,
**kwargs,
Expand Down Expand Up @@ -86,10 +90,11 @@ def labelLine(
def labelLines(
lines=None,
align=True,
xvals=None,
vals=None,
axis=None,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have the same signature as above (i.e. axis defaults to x).

drop_label=False,
shrink_factor=0.05,
yoffsets=0,
offsets=0,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above about maintaining backward compatibility.

outline_color="auto",
outline_width=5,
**kwargs,
Expand All @@ -103,17 +108,20 @@ def labelLines(
align : boolean, optional
If True, the label will be aligned with the slope of the line
at the location of the label. If False, they will be horizontal.
xvals : (xfirst, xlast) or array of float, optional
vals : (first, last) or array of float, optional
The location of the labels. If a tuple, the labels will be
evenly spaced between xfirst and xlast (in the axis units).
evenly spaced between first and last (in the axis units).
axis : None | "x" | "y", optional
Reference axis for the `vals`.
Comment on lines +114 to +115
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once you update the signature, don't forget to update the docstring as well ;)

drop_label : bool, optional
If True, the label is consumed by the function so that subsequent
calls to e.g. legend do not use it anymore.
shrink_factor : double, optional
Relative distance from the edges to place closest labels. Defaults to 0.05.
yoffsets : number or list, optional.
offsets : number or list, optional.
Distance relative to the line when positioning the labels. If given a number,
the same value is used for all lines.
the same value is used for all lines. It refers to the *other* axis
(i.e. to y if axis=="x")
outline_color : None | "auto" | color
Colour of the outline. If set to "auto", use the background color.
If set to None, do not draw an outline.
Expand All @@ -122,11 +130,18 @@ def labelLines(
kwargs : dict, optional
Optional arguments passed to ax.text
"""

if lines:
ax = lines[0].axes
else:
ax = plt.gca()

if axis == "y":
yaxis = True
else:
axis = "x"
yaxis = False
Comment on lines +139 to +143
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the check below will be performed at different locations (here and at line 174), I'd suggest factorising it somehow.
I would also recommend explicitly testing the axis to be either x, y or otherwise fail with an error.

Suggested change
if axis == "y":
yaxis = True
else:
axis = "x"
yaxis = False
if axis == "y":
yaxis = True
elif axis == "x":
yaxis = False
else:
raise ValueError(r"Got an invalid axis {axis}, expected 'x' or 'y'")


handles, labels_of_handles = ax.get_legend_handles_labels()

all_lines, all_labels = [], []
Expand Down Expand Up @@ -156,85 +171,93 @@ def labelLines(

# In case no x location was provided, we need to use some heuristics
# to generate them.
if xvals is None:
xvals = ax.get_xlim()
xvals_rng = xvals[1] - xvals[0]
shrinkage = xvals_rng * shrink_factor
xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage)

if isinstance(xvals, tuple) and len(xvals) == 2:
xmin, xmax = xvals
if vals is None:
if yaxis:
vals = ax.get_ylim()
else:
vals = ax.get_xlim()
vals_rng = vals[1] - vals[0]
shrinkage = vals_rng * shrink_factor
vals = (vals[0] + shrinkage, vals[1] - shrinkage)

if isinstance(vals, tuple) and len(vals) == 2:
vmin, vmax = vals
xscale = ax.get_xscale()
if xscale == "log":
xvals = np.logspace(np.log10(xmin), np.log10(xmax), len(all_lines) + 2)[
1:-1
]
vals = np.logspace(np.log10(vmin), np.log10(vmax), len(all_lines) + 2)[1:-1]
else:
xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1]
vals = np.linspace(vmin, vmax, len(all_lines) + 2)[1:-1]
Comment on lines 185 to +189
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't implement the logic for the y axis.


# Build matrix line -> xvalue
# Build matrix line -> value
ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool)

for i, line in enumerate(all_lines):
xdata, _ = normalize_xydata(line)
minx, maxx = min(xdata), max(xdata)
for j, xv in enumerate(xvals):
ok_matrix[i, j] = minx < xv < maxx
if yaxis:
_, data = normalize_xydata(line)
else:
data, _ = normalize_xydata(line)
minv, maxv = min(data), max(data)
for j, val in enumerate(vals):
ok_matrix[i, j] = minv < val < maxv

# If some xvals do not fall in their corresponding line,
# If some vals do not fall in their corresponding line,
# find a better matching using maximum bipartite matching.
if not np.all(np.diag(ok_matrix)):
order = maximum_bipartite_matching(ok_matrix)

# The maximum match may miss a few points, let's add them back
order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0])

# Now reorder the xvalues
old_xvals = xvals.copy()
xvals[order] = old_xvals
# Now reorder the values
old_vals = vals.copy()
vals[order] = old_vals
else:
xvals = list(always_iterable(xvals)) # force the creation of a copy
vals = list(always_iterable(vals)) # force the creation of a copy

lab_lines, labels = [], []
# Take only the lines which have labels other than the default ones
for i, (line, xv) in enumerate(zip(all_lines, xvals)):
for i, (line, val) in enumerate(zip(all_lines, vals)):
label = all_labels[all_lines.index(line)]
lab_lines.append(line)
labels.append(label)

# Move xlabel if it is outside valid range
xdata, _ = normalize_xydata(line)
if not (min(xdata) <= xv <= max(xdata)):
# Move xlabel/ylabel if it is outside valid range
if yaxis:
_, data = normalize_xydata(line)
else:
data, _ = normalize_xydata(line)
if not (min(data) <= val <= max(data)):
warnings.warn(
(
"The value at position {} in `xvals` is outside the range of its "
"associated line (xmin={}, xmax={}, xval={}). Clipping it "
"The value at position {} in `vals` is outside the range of its "
"associated line (vmin={}, vmax={}, val={}). Clipping it "
"into the allowed range."
).format(i, min(xdata), max(xdata), xv),
).format(i, min(data), max(data), val),
UserWarning,
stacklevel=1,
)
new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9
xvals[i] = new_xv
new_val = min(data) + (max(data) - min(data)) * 0.9
vals[i] = new_val

# Convert float values back to datetime in case of datetime axis
if isinstance(ax.xaxis.converter, DateConverter):
xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals]
vals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in vals]

txts = []
try:
yoffsets = [float(yoffsets)] * len(all_lines)
offsets = [float(offsets)] * len(all_lines)
except TypeError:
pass
for line, x, yoffset, label in zip(lab_lines, xvals, yoffsets, labels):
for line, val, offset, label in zip(lab_lines, vals, offsets, labels):
txts.append(
labelLine(
line,
x,
val,
axis,
label=label,
align=align,
drop_label=drop_label,
yoffset=yoffset,
offset=offset,
outline_color=outline_color,
outline_width=outline_width,
**kwargs,
Expand Down
Loading