Skip to content

Commit

Permalink
Add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
cphyc committed Sep 28, 2023
1 parent 5c2888f commit 7fa9edd
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions labellines/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import warnings
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.container import ErrorbarContainer
from matplotlib.dates import DateConverter, num2date
from matplotlib.lines import Line2D
from more_itertools import always_iterable

from .line_label import LineLabel
Expand All @@ -12,16 +14,16 @@

# Label line with line2D label data
def labelLine(
line,
line: Line2D,
x,
label=None,
align=None,
drop_label=False,
yoffset=0,
yoffset_logspace=False,
outline_color="auto",
outline_width=8,
rotation=None,
label: Optional[str] = None,
align: Optional[bool] = None,
drop_label: bool = False,
yoffset: float = 0,
yoffset_logspace: bool = False,
outline_color: str = "auto",
outline_width: float = 8,
rotation: Optional[float] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -91,15 +93,15 @@ def labelLine(


def labelLines(
lines=None,
align=None,
xvals=None,
drop_label=False,
shrink_factor=0.05,
yoffsets=0,
outline_color="auto",
outline_width=5,
rotation=None,
lines: Optional[List[Line2D]] = None,
align: Optional[bool] = None,
xvals: Optional[Union[tuple[float, float], list[float]]] = None,
drop_label: bool = False,
shrink_factor: float = 0.05,
yoffsets: Union[float, list[float]] = 0,
outline_color: str = "auto",
outline_width: float = 5,
rotation: Optional[bool] = None,
**kwargs,
):
"""Label all lines with their respective legends.
Expand Down Expand Up @@ -196,7 +198,7 @@ def labelLines(
for i, line in enumerate(all_lines):
xdata, _ = normalize_xydata(line)
minx, maxx = min(xdata), max(xdata)
for j, xv in enumerate(xvals):
for j, xv in enumerate(xvals): # type: ignore
ok_matrix[i, j] = minx < xv < maxx

# If some xvals do not fall in their corresponding line,
Expand All @@ -208,14 +210,14 @@ def labelLines(
order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0])

# Now reorder the xvalues
old_xvals = xvals.copy()
xvals[order] = old_xvals
old_xvals = xvals.copy() # type: ignore
xvals[order] = old_xvals # type: ignore
else:
xvals = list(always_iterable(xvals)) # force the creation of a copy

lab_lines, labels = [], []
# Take only the lines which have labels other than the default ones
for i, (line, xv) in enumerate(zip(all_lines, xvals)):
for i, (line, xv) in enumerate(zip(all_lines, xvals)): # type: ignore
label = all_labels[all_lines.index(line)]
lab_lines.append(line)
labels.append(label)
Expand All @@ -233,18 +235,23 @@ def labelLines(
stacklevel=1,
)
new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9
xvals[i] = new_xv
xvals[i] = new_xv # type: ignore

# Convert float values back to datetime in case of datetime axis
if isinstance(ax.xaxis.converter, DateConverter):
xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals]
xvals = [
num2date(x).replace(tzinfo=ax.xaxis.get_units())
for x in xvals # type: ignore
]

txts = []
try:
yoffsets = [float(yoffsets)] * len(all_lines)
yoffsets = [float(yoffsets)] * len(all_lines) # type: ignore
except TypeError:
pass
for line, x, yoffset, label in zip(lab_lines, xvals, yoffsets, labels):
for line, x, yoffset, label in zip(
lab_lines, xvals, yoffsets, labels # type: ignore
):
txts.append(
labelLine(
line,
Expand Down

0 comments on commit 7fa9edd

Please sign in to comment.