forked from cphyc/matplotlib-label-lines
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LabelLines_Arrow.py
166 lines (131 loc) · 5.19 KB
/
LabelLines_Arrow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from math import atan2, degrees
import numpy as np
from matplotlib.dates import date2num
from matplotlib import textpath
from datetime import datetime
# Label line with line2D label data
def labelLine(line, x, label=None, align=True, arrow=False, arrow_scale=(0.25, 0.3), **kwargs):
    '''Label a single matplotlib line at position x.

    The label is drawn with ``ax.text`` at the point of the line whose
    x-coordinate is ``x``; the y-coordinate is obtained by linear
    interpolation between the two data points bracketing ``x``.

    Parameters
    ----------
    line : matplotlib.lines.Line2D
        The line holding the label.
    x : number or datetime.datetime
        The location in data units of the label. ``datetime`` values are
        converted with ``matplotlib.dates.date2num``.
    label : string, optional
        The label to set. Inferred from the line (``line.get_label()``) when
        omitted or empty.
    align : bool, optional
        If True, rotate the label to follow the local slope of the line;
        otherwise the label is horizontal.
    arrow : bool, optional
        Draw an arrow below the label, pointing towards the side (left or
        right) on which the y-axis label of ``line.axes`` is placed.
    arrow_scale : (float, float), optional
        ``(x, y)`` scaling of the arrow. The x factor times the x-tick
        spacing (or times ``x`` itself on a log-scaled x axis) gives the
        arrow's half-length; the y factor times the y-tick spacing gives how
        far below the label the arrow is drawn.
    kwargs : dict, optional
        Optional arguments passed to ``ax.text`` (and to ``ax.annotate``
        when drawing the arrow).

    Raises
    ------
    ValueError
        If the line has no data, or ``x`` lies outside its x-data range.
    '''
    ax = line.axes
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())
    if xdata.size == 0:
        raise ValueError('line has no data to label!')
    # Sort by x so the segment containing x can be found by binary search.
    order = np.argsort(xdata)
    xdata = xdata[order]
    ydata = ydata[order]

    # Convert datetime objects to floats
    if isinstance(x, datetime):
        x = date2num(x)

    xmin, xmax = xdata[0], xdata[-1]
    if x < xmin or x > xmax:
        raise ValueError('x label location is outside data range!')

    # Segment [xdata[ip-1], xdata[ip]] containing x: ip is the first point
    # strictly right of x, clamped so that x == xmax uses the last segment
    # instead of extrapolating from the first one.
    if len(xdata) > 1:
        ip = int(np.searchsorted(xdata, x, side='right'))
        ip = min(max(ip, 1), len(xdata) - 1)
        x0, x1 = xdata[ip - 1], xdata[ip]
        y0, y1 = ydata[ip - 1], ydata[ip]
    else:
        # A single data point: nothing to interpolate, treat it as flat.
        x0 = x1 = xdata[0]
        y0 = y1 = ydata[0]
    dx = x1 - x0
    dy = y1 - y0
    # Zero-width segments (duplicate x values) cannot be interpolated.
    y = y0 + dy * (x - x0) / dx if dx != 0 else y0

    if not label:
        label = line.get_label()

    if align:
        # Slope angle in data space, converted to screen space so the text
        # follows the line as displayed.
        ang = degrees(atan2(dy, dx))
        pt = np.array([x, y]).reshape((1, 2))
        trans_angle = ax.transData.transform_angles(np.array((ang, )), pt)[0]
    else:
        trans_angle = 0

    # Default text styling; anything passed explicitly by the caller wins.
    kwargs.setdefault('color', line.get_color())
    if 'horizontalalignment' not in kwargs and 'ha' not in kwargs:
        kwargs['ha'] = 'center'
    if 'verticalalignment' not in kwargs and 'va' not in kwargs:
        kwargs['va'] = 'center'
    kwargs.setdefault('backgroundcolor', ax.get_facecolor())
    kwargs.setdefault('clip_on', True)
    kwargs.setdefault('zorder', 2.5)

    ax.text(x, y, label, rotation=trans_angle, **kwargs)

    if arrow:
        # Point towards the side carrying this axes' y-axis label.
        x_dir = 1 if ax.yaxis.get_label_position() == 'right' else -1
        # Vertical offset below the label, in units of the y-tick spacing.
        yticks = ax.get_yticks()
        y_arrow = y - arrow_scale[1] * (yticks[1] - yticks[0])
        # Half-length: proportional to x on a log x axis; for every other
        # scale (linear, symlog, ...) proportional to the x-tick spacing.
        if ax.get_xscale() == 'log':
            half_len = arrow_scale[0] * x
        else:
            xticks = ax.get_xticks()
            half_len = arrow_scale[0] * (xticks[1] - xticks[0])
        x_arrow_start = x - x_dir * half_len
        x_arrow_end = x + x_dir * half_len
        ax.annotate("", xytext=(x_arrow_start, y_arrow), xy=(x_arrow_end, y_arrow),
                    arrowprops=dict(lw=2, arrowstyle='->', color=line.get_color()),
                    **kwargs)
def labelLines(lines, align=True, xvals=None, draw_arrow=False, arrow_scale=(0.25, 0.3), **kwargs):
    '''Label all lines with their respective legends.

    Parameters
    ----------
    lines : list of matplotlib lines
        The lines to label. Lines whose label starts with an underscore
        (matplotlib's convention for "no legend entry", which includes the
        auto-generated default labels) are skipped. Default label positions
        are computed from ``lines[0].axes``. An empty list is a no-op.
    align : boolean, optional
        If True, the label will be aligned with the slope of the line
        at the location of the label. If False, they will be horizontal.
    xvals : (xfirst, xlast) or array of float, optional
        The location of the labels. If a tuple, the labels will be
        evenly spaced between xfirst and xlast (in the axis units; spacing
        is logarithmic on a log-scaled x axis), excluding both endpoints.
        Defaults to the current x limits of the axes.
    draw_arrow : boolean, optional
        If True, an arrow will be drawn below the label. This arrow is
        pointing to the axis of the specific matplotlib line
    arrow_scale : (float, float), optional
        ``(x, y)`` scaling of the arrow, forwarded to ``labelLine``.
    kwargs : dict, optional
        Optional arguments passed to ax.text
    '''
    if len(lines) == 0:
        return
    ax = lines[0].axes

    # Keep only lines with a user-visible label; matplotlib marks "no label"
    # (including its auto-generated defaults) with a leading underscore.
    labLines = []
    labels = []
    for line in lines:
        label = line.get_label()
        if not label.startswith('_'):
            labLines.append(line)
            labels.append(label)

    if xvals is None:
        # Axis limits as annotation limits; get_xlim returns a tuple.
        xvals = ax.get_xlim()
    if isinstance(xvals, tuple):
        xmin, xmax = xvals
        # Two extra points so the endpoints can be dropped below.
        n_points = len(labLines) + 2
        if ax.get_xscale() == "log":
            xvals = np.logspace(np.log10(xmin), np.log10(xmax), n_points)[1:-1]
        else:
            xvals = np.linspace(xmin, xmax, n_points)[1:-1]

    for line, x, label in zip(labLines, xvals, labels):
        labelLine(line, x, label, align, arrow=draw_arrow, arrow_scale=arrow_scale, **kwargs)