-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils_plot.py
224 lines (180 loc) · 8.26 KB
/
utils_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def interpolate(dataset, data_name, target_shape, d, inds):
    """Gather a variable at precomputed flat indices and lay it out on a target grid.

    No arithmetic interpolation happens here: the values of
    ``dataset[data_name]`` are flattened, the entries at ``inds`` are picked,
    and the result is reshaped in place to ``target_shape.shape``.

    Parameters
    ----------
    dataset : mapping
        Container indexable by ``data_name`` whose items expose ``.values``
        (e.g. a pandas/xarray object).
    data_name : hashable
        Key of the variable to extract; also the key of the returned dict.
    target_shape : array-like with a ``.shape`` attribute
        Only its ``.shape`` is used, as the shape of the output array.
    d : object
        Not used.
    inds : index array
        Flat indices into the flattened variable; presumably precomputed
        nearest-neighbour indices -- confirm with the caller.

    Returns
    -------
    dict
        ``{data_name: [array]}`` -- a one-element list holding the reshaped values.
    """
    # The shape is assigned in place (not via reshape) so an incompatible
    # layout raises instead of silently copying.
    picked = dataset[data_name].values.flatten()[inds]
    picked.shape = target_shape.shape
    return {data_name: [picked]}
def lon_lat_to_cartesian(lon, lat):
    """Map longitude/latitude in degrees to Earth-centred x, y, z in km.

    Points are placed on the WGS 84 ellipsoid surface (zero height). Works
    element-wise on scalars and NumPy arrays alike.

    Returns
    -------
    tuple
        ``(x, y, z)`` in kilometres.
    """
    semi_major_km = 6378.137       # WGS 84 semi-major axis a [km]
    ecc_sq = 6.69437999014e-3      # WGS 84 first eccentricity squared e^2

    phi = np.radians(lat)
    lam = np.radians(lon)
    sin_phi = np.sin(phi)
    cos_phi = np.cos(phi)

    # Prime-vertical radius of curvature N(phi) = a / sqrt(1 - e^2 sin^2 phi).
    n_phi = semi_major_km / np.sqrt(1 - ecc_sq * (sin_phi ** 2))

    x = n_phi * cos_phi * np.cos(lam)
    y = n_phi * cos_phi * np.sin(lam)
    z = n_phi * (1 - ecc_sq) * sin_phi
    return x, y, z
class _HeatMapper2(object):
    """Draw a heatmap whose cells are squares sized by a second matrix.

    Adapted from seaborn's private ``_HeatMapper``: rather than a
    ``pcolormesh``, ``plot`` adds one ``Rectangle`` per cell, coloured by the
    cell's value and with a side length taken from ``cellsize``. Tick
    positions and labels are fixed for a 12 x 12 matrix: lead time 1-12
    months on the x axis, initialization month Jan-Dec on the y axis.
    """

    def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
                 annot_kws, cellsize, cellsize_vmax,
                 cbar, cbar_kws,
                 xticklabels=True, yticklabels=True, mask=None, ax_kws=None,
                 rect_kws=None, title=None):
        """Store the data and all plotting options.

        Parameters
        ----------
        data : 2-D array-like or pandas.DataFrame
            Values that set each cell's colour.
        vmin, vmax : float or None
            Colour-scale limits. ``None`` is filled from the finite values of
            ``data`` (min/max, or 2nd/98th percentile when ``robust``).
        cmap : matplotlib Colormap, str or None
            Colormap for the cells. A name or ``None`` is resolved with
            ``plt.get_cmap`` (``None`` gives matplotlib's default colormap).
        center, xticklabels, yticklabels, mask
            Accepted for signature compatibility with seaborn's heatmap; not used.
        robust : bool
            Use percentiles instead of the extremes for missing limits.
        annot : None, bool or 2-D array-like
            ``True`` annotates each cell with its value, an array supplies the
            annotation values, ``None``/``False`` disables annotation.
        fmt : str
            Stored but not applied: annotations are the first four characters
            of ``str(value)``.
        annot_kws, cbar_kws, ax_kws, rect_kws : dict or None
            Extra keyword arguments for ``ax.text``, ``Figure.colorbar``,
            ``ax.set`` and ``plt.Rectangle``. They are copied, so the caller's
            dicts are never modified.
        cellsize : 2-D array-like, pandas.DataFrame or None
            Values that set each cell's side length; ``None`` draws every cell
            at full size.
        cellsize_vmax : float or None
            ``cellsize`` value drawn as a full-size cell; ``None`` uses the
            largest (non-NaN) value of ``cellsize``.
        cbar : bool
            Whether ``plot`` draws a colorbar.
        title : str or None
            Axes title drawn by ``plot``.

        Raises
        ------
        ValueError
            If an array passed as ``annot`` does not have the shape of ``data``.
        """
        # Keep a DataFrame (for shapes/limits) and a plain ndarray (for drawing).
        if isinstance(data, pd.DataFrame):
            plot_data = data.values
        else:
            plot_data = np.asarray(data)
            data = pd.DataFrame(plot_data)
        self.data = data
        self.plot_data = plot_data

        # Fixed tick layout for a 12 x 12 (lead time x initialization month) matrix.
        self.xticks = 0.5 + np.arange(12)
        self.xticklabels = [str(month) for month in range(1, 13)]
        self.yticks = 0.5 + np.arange(12)
        self.yticklabels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                            'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
        self.xlabel = 'Lead time (months)'
        self.ylabel = 'Initialization time'
        self.title = title

        self._determine_cellsize_params(plot_data, cellsize, cellsize_vmax)
        self.vmin, self.vmax = self._determine_vrange(plot_data, vmin, vmax, robust)
        if cmap is None or isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        self.cmap = cmap
        self.annot, self.annot_data = self._determine_annot_params(plot_data, annot)
        self.fmt = fmt

        # Copy the option dicts so setdefault below never touches the caller's.
        self.annot_kws = dict(annot_kws or {})
        self.annot_kws.setdefault('color', "black")
        self.annot_kws.setdefault('ha', "center")
        self.annot_kws.setdefault('va', "center")
        self.cbar = cbar
        self.cbar_kws = dict(cbar_kws or {})
        self.cbar_kws.setdefault('ticks', mpl.ticker.MaxNLocator(6))
        self.ax_kws = dict(ax_kws or {})
        self.rect_kws = dict(rect_kws or {})

    @staticmethod
    def _determine_annot_params(plot_data, annot):
        """Return ``(annot_flag, annot_data)`` from the user's ``annot`` argument."""
        if annot is None:
            return False, None
        if isinstance(annot, bool):
            return annot, (plot_data if annot else None)
        # Array-like annotations: unwrap pandas objects, then require a matching shape.
        try:
            annot_data = annot.values
        except AttributeError:
            annot_data = annot
        annot_data = np.asarray(annot_data)
        if annot_data.shape != np.shape(plot_data):
            raise ValueError('Data supplied to "annot" must be the same '
                             'shape as the data to plot.')
        return True, annot_data

    @staticmethod
    def _determine_vrange(plot_data, vmin, vmax, robust=False):
        """Return ``(vmin, vmax)``, filling ``None`` limits from the finite data."""
        if vmin is not None and vmax is not None:
            return vmin, vmax
        values = np.asarray(plot_data, dtype=float)
        finite = values[np.isfinite(values)]
        if vmin is None:
            if finite.size:
                vmin = float(np.percentile(finite, 2) if robust else finite.min())
            else:
                vmin = 0.0
        if vmax is None:
            if finite.size:
                vmax = float(np.percentile(finite, 98) if robust else finite.max())
            else:
                vmax = 1.0
        return vmin, vmax

    def _determine_cellsize_params(self, plot_data, cellsize, cellsize_vmax):
        """Store the cell-size matrix and the value that maps to a full cell."""
        if cellsize is None:
            # No size information: every cell is drawn at full size.
            cellsize = np.ones(np.shape(plot_data))
        elif isinstance(cellsize, pd.DataFrame):
            cellsize = cellsize.values
        else:
            cellsize = np.asarray(cellsize)
        self.cellsize = cellsize
        if cellsize_vmax is None:
            cellsize_vmax = np.nanmax(cellsize)
        self.cellsize_vmax = cellsize_vmax

    def plot(self, data, ax, cax):
        """Draw the heatmap on ``ax`` and, if enabled, a colorbar.

        Parameters
        ----------
        data : 2-D array-like or pandas.DataFrame
            Values that colour the cells; normally the matrix given to the
            constructor (positions and sizes come from the stored data).
        ax : matplotlib Axes
            Axes to draw on.
        cax : matplotlib Axes or None
            Axes for the colorbar; ``None`` lets matplotlib take space from ``ax``.
        """
        # Imported here so withStroke works even if nothing else loaded the submodule.
        from matplotlib import patheffects

        # Remove all the Axes spines
        despine(ax=ax, left=True, bottom=True)

        values = np.asarray(data)
        height, width = self.plot_data.shape
        xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)
        annot_data = self.annot_data if self.annot else np.zeros(self.plot_data.shape)
        span = self.vmax - self.vmin

        # One Rectangle per cell instead of a pcolormesh, so each cell can have
        # its own size (slower than seaborn's heatmap).
        for x, y, val, s, an_val in zip(xpos.flat, ypos.flat, values.flat,
                                        self.cellsize.flat, annot_data.flat):
            # Normalise to [0, 1]; a zero-width range maps every cell to the
            # bottom colour instead of dividing by zero.
            vv = (val - self.vmin) / span if span else 0.0
            # Side length relative to cellsize_vmax, kept within [0.1, 1] of a cell.
            size = np.clip(s / self.cellsize_vmax, 0.1, 1.0)
            rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                                 facecolor=self.cmap(vv), **self.rect_kws)
            ax.add_patch(rect)
            if self.annot:
                # Annotation = first four characters of str(value); fmt is not applied.
                text = ax.text(x, y, str(an_val)[:4], **self.annot_kws)
                # Outline the text in a colour that contrasts with the text colour.
                text_luminance = relative_luminance(text.get_color())
                text_edge_color = ".15" if text_luminance > .408 else "w"
                text.set_path_effects(
                    [patheffects.withStroke(linewidth=1, foreground=text_edge_color)])

        # Set the axis limits, then any user-supplied Axes properties
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))
        ax.set(**self.ax_kws)

        if self.cbar:
            norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
            scalar_mappable = mpl.cm.ScalarMappable(cmap=self.cmap, norm=norm)
            scalar_mappable.set_array(self.plot_data)
            cb = ax.figure.colorbar(scalar_mappable, cax=cax, ax=ax, **self.cbar_kws)
            cb.outline.set_linewidth(0)
            cb.ax.tick_params(labelsize=18)

        # Fixed row and column labels set up in __init__
        ax.set(xticks=self.xticks, yticks=self.yticks)
        xtl = ax.set_xticklabels(self.xticklabels, fontsize=18)
        ytl = ax.set_yticklabels(self.yticklabels, rotation="vertical", fontsize=18)

        # Render once so label extents are known, then rotate if they overlap
        ax.figure.draw(ax.figure.canvas.get_renderer())
        if axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        if axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")

        # Title and axis labels (a None title renders as empty)
        ax.set_title(self.title, fontsize=20)
        ax.set_xlabel(self.xlabel, fontsize=20)
        ax.set_ylabel(self.ylabel, fontsize=20)

        # Invert the y axis to show the plot in matrix form
        ax.invert_yaxis()


def heatmap2(array, title, vmin=None, vmax=None, cmap=None, center=None, robust=False,
             annot=None, fmt=".2g", annot_kws=None,
             cellsize=None, cellsize_vmax=None,
             cbar=True, cbar_kws=None, cbar_ax=None,
             square=False, xticklabels="auto", yticklabels="auto",
             mask=None, ax=None, ax_kws=None, rect_kws=None):
    """Draw a colour-and-size heatmap of ``array`` and return the Axes.

    Cell colour comes from ``array``; cell side length comes from
    ``cellsize`` (full cells when ``None``). See ``_HeatMapper2.__init__``
    for the meaning of each option. ``center``, ``xticklabels``,
    ``yticklabels`` and ``mask`` are accepted but not used.

    Parameters
    ----------
    array : 2-D array-like or pandas.DataFrame
        Matrix to plot; the fixed tick labels assume 12 x 12.
    title : str or None
        Axes title.
    ax : matplotlib Axes or None
        Target Axes; defaults to ``plt.gca()``.
    square : bool
        Force an equal aspect ratio.

    Returns
    -------
    matplotlib Axes
        The Axes the heatmap was drawn on.
    """
    plotter = _HeatMapper2(array, vmin, vmax, cmap, center, robust,
                           annot, fmt, annot_kws,
                           cellsize, cellsize_vmax,
                           cbar, cbar_kws, xticklabels,
                           yticklabels, mask, ax_kws, rect_kws, title=title)
    if ax is None:
        ax = plt.gca()
    if square:
        ax.set_aspect("equal")
    # Light dashed grid behind the cells.
    ax.grid(alpha=0.3, ls="--", lw=1)
    plotter.plot(array, ax, cbar_ax)
    return ax