Skip to content

Commit

Permalink
Added option to show basic statistics in plots of `monitor/multi_data…
Browse files Browse the repository at this point in the history
…sets.py` (#2790)

* Added option to show basic statistics in monitoring diag

* Added description of number in top left corner to caption

* Used weighted R2 instead of unweighted R2
  • Loading branch information
schlunma authored Sep 7, 2022
1 parent bc4a43c commit 752c3f6
Showing 1 changed file with 102 additions and 7 deletions.
109 changes: 102 additions & 7 deletions esmvaltool/diag_scripts/monitor/multi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@
<https://matplotlib.org/stable/gallery/misc/rasterization_demo.html>`_ for
map plots to produce smaller files. This is only relevant for vector
graphics (e.g., ``output_file_type=pdf,svg,ps``).
show_stats: bool, optional (default: True)
Show basic statistics on the plots.
Configuration options for plot type ``profile``
-----------------------------------------------
Expand Down Expand Up @@ -243,6 +245,8 @@
graphics (e.g., ``output_file_type=pdf,svg,ps``).
show_y_minor_ticklabels: bool, optional (default: False)
Show tick labels for the minor ticks on the Y axis.
show_stats: bool, optional (default: True)
Show basic statistics on the plots.
.. hint::
Expand All @@ -260,10 +264,14 @@
import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from iris.analysis.cartography import area_weights
from iris.coord_categorisation import add_year
from iris.coords import AuxCoord
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FormatStrFormatter, NullFormatter
from sklearn.metrics import r2_score

import esmvaltool.diag_scripts.shared.iris_helpers as ih
from esmvaltool.diag_scripts.monitor.monitor_base import MonitorBase
Expand Down Expand Up @@ -320,6 +328,7 @@ def __init__(self, config):
self.plots[plot_type].setdefault('common_cbar', False)
self.plots[plot_type].setdefault('plot_func', 'contourf')
self.plots[plot_type].setdefault('rasterize', True)
self.plots[plot_type].setdefault('show_stats', True)

# Defaults profile plots
if plot_type == 'profile':
Expand Down Expand Up @@ -365,6 +374,66 @@ def _add_colorbar(self, plot_type, plot_left, plot_right, axes_left,
cbar_right.set_label(cbar_label_right, fontsize=fontsize)
cbar_right.ax.tick_params(labelsize=fontsize)

def _add_stats(self, plot_type, axes, dim_coords, cube, ref_cube=None):
"""Add text to plot that describes basic statistics."""
if not self.plots[plot_type]['show_stats']:
return

# Different options for the different plots types
fontsize = 6.0
y_pos = 0.95
if plot_type == 'map':
x_pos_bias = 0.92
x_pos = 0.0
elif plot_type == 'profile':
x_pos_bias = 0.7
x_pos = 0.01
else:
raise NotImplementedError(f"plot_type '{plot_type}' not supported")

# For profile plots add scalar longitude coordinate (necessary for
# calculation of area weights). The exact values for the points/bounds
# of this coordinate do not matter since they don't change the weights.
if not cube.coords('longitude'):
lon_coord = AuxCoord(
180.0,
bounds=[0.0, 360.0],
var_name='lon',
standard_name='longitude',
long_name='longitude',
units='degrees_east',
)
cube.add_aux_coord(lon_coord, ())

# Mean
weights = area_weights(cube)
if ref_cube is None:
mean = cube.collapsed(dim_coords, iris.analysis.MEAN,
weights=weights)
else:
mean = (cube - ref_cube).collapsed(dim_coords, iris.analysis.MEAN,
weights=weights)
axes.text(x_pos, y_pos, f"{mean.data:.2f}{cube.units}",
fontsize=fontsize, transform=axes.transAxes)
if ref_cube is None:
return

# Weighted RMSE
rmse = (cube - ref_cube).collapsed(dim_coords, iris.analysis.RMS,
weights=weights)
axes.text(x_pos_bias, y_pos, f"RMSE={rmse.data:.2f}{cube.units}",
fontsize=fontsize, transform=axes.transAxes)

# Weighted R2
mask = np.ma.getmaskarray(cube.data).ravel()
mask |= np.ma.getmaskarray(ref_cube.data).ravel()
cube_data = cube.data.ravel()[~mask]
ref_cube_data = ref_cube.data.ravel()[~mask]
weights = weights.ravel()[~mask]
r2_val = r2_score(cube_data, ref_cube_data, sample_weight=weights)
axes.text(x_pos_bias, y_pos - 0.1, rf"R$^2$={r2_val:.2f}",
fontsize=fontsize, transform=axes.transAxes)

def _get_custom_mpl_rc_params(self, plot_type):
"""Get custom matplotlib rcParams."""
fontsize = self.plots[plot_type]['fontsize']
Expand Down Expand Up @@ -520,8 +589,8 @@ def _plot_map_with_ref(self, plot_func, dataset, ref_dataset):
# Make sure that the data has the correct dimensions
cube = dataset['cube']
ref_cube = ref_dataset['cube']
self._check_cube_dimensions(cube, plot_type)
self._check_cube_dimensions(ref_cube, plot_type)
dim_coords_dat = self._check_cube_dimensions(cube, plot_type)
dim_coords_ref = self._check_cube_dimensions(ref_cube, plot_type)

# Create single figure with multiple axes
with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
Expand All @@ -544,6 +613,7 @@ def _plot_map_with_ref(self, plot_func, dataset, ref_dataset):
if gridline_kwargs is not False:
axes_data.gridlines(**gridline_kwargs)
axes_data.set_title(self._get_label(dataset), pad=3.0)
self._add_stats(plot_type, axes_data, dim_coords_dat, cube)

# Plot reference dataset (top right)
# Note: make sure to use the same vmin and vmax than the top left
Expand All @@ -559,6 +629,7 @@ def _plot_map_with_ref(self, plot_func, dataset, ref_dataset):
if gridline_kwargs is not False:
axes_ref.gridlines(**gridline_kwargs)
axes_ref.set_title(self._get_label(ref_dataset), pad=3.0)
self._add_stats(plot_type, axes_ref, dim_coords_ref, ref_cube)

# Add colorbar(s)
self._add_colorbar(plot_type, plot_data, plot_ref, axes_data,
Expand Down Expand Up @@ -587,6 +658,8 @@ def _plot_map_with_ref(self, plot_func, dataset, ref_dataset):
fontsize=fontsize,
)
cbar_bias.ax.tick_params(labelsize=fontsize)
self._add_stats(plot_type, axes_bias, dim_coords_dat, cube,
ref_cube)

# Customize plot
fig.suptitle(f"{dataset['long_name']} ({dataset['start_year']}-"
Expand All @@ -607,7 +680,7 @@ def _plot_map_without_ref(self, plot_func, dataset):

# Make sure that the data has the correct dimensions
cube = dataset['cube']
self._check_cube_dimensions(cube, plot_type)
dim_coords_dat = self._check_cube_dimensions(cube, plot_type)

# Create plot with desired settings
with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
Expand All @@ -621,6 +694,9 @@ def _plot_map_without_ref(self, plot_func, dataset):
if gridline_kwargs is not False:
axes.gridlines(**gridline_kwargs)

# Print statistics if desired
self._add_stats(plot_type, axes, dim_coords_dat, cube)

# Setup colorbar
fontsize = self.plots[plot_type]['fontsize']
colorbar = fig.colorbar(plot_map, ax=axes,
Expand Down Expand Up @@ -650,8 +726,8 @@ def _plot_profile_with_ref(self, plot_func, dataset, ref_dataset):
# Make sure that the data has the correct dimensions
cube = dataset['cube']
ref_cube = ref_dataset['cube']
self._check_cube_dimensions(cube, plot_type)
self._check_cube_dimensions(ref_cube, plot_type)
dim_coords_dat = self._check_cube_dimensions(cube, plot_type)
dim_coords_ref = self._check_cube_dimensions(ref_cube, plot_type)

# Create single figure with multiple axes
with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
Expand Down Expand Up @@ -679,6 +755,7 @@ def _plot_profile_with_ref(self, plot_func, dataset, ref_dataset):
FormatStrFormatter('%.1f'))
else:
axes_data.get_yaxis().set_minor_formatter(NullFormatter())
self._add_stats(plot_type, axes_data, dim_coords_dat, cube)

# Plot reference dataset (top right)
# Note: make sure to use the same vmin and vmax than the top left
Expand All @@ -692,6 +769,7 @@ def _plot_profile_with_ref(self, plot_func, dataset, ref_dataset):
plot_ref = plot_func(ref_cube, **plot_kwargs)
axes_ref.set_title(self._get_label(ref_dataset), pad=3.0)
plt.setp(axes_ref.get_yticklabels(), visible=False)
self._add_stats(plot_type, axes_ref, dim_coords_ref, ref_cube)

# Add colorbar(s)
self._add_colorbar(plot_type, plot_data, plot_ref, axes_data,
Expand Down Expand Up @@ -719,6 +797,8 @@ def _plot_profile_with_ref(self, plot_func, dataset, ref_dataset):
fontsize=fontsize,
)
cbar_bias.ax.tick_params(labelsize=fontsize)
self._add_stats(plot_type, axes_bias, dim_coords_dat, cube,
ref_cube)

# Customize plot
fig.suptitle(f"{dataset['long_name']} ({dataset['start_year']}-"
Expand All @@ -739,7 +819,7 @@ def _plot_profile_without_ref(self, plot_func, dataset):

# Make sure that the data has the correct dimensions
cube = dataset['cube']
self._check_cube_dimensions(cube, plot_type)
dim_coords_dat = self._check_cube_dimensions(cube, plot_type)

# Create plot with desired settings
with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
Expand All @@ -749,6 +829,9 @@ def _plot_profile_without_ref(self, plot_func, dataset):
plot_kwargs['axes'] = axes
plot_profile = plot_func(cube, **plot_kwargs)

# Print statistics if desired
self._add_stats(plot_type, axes, dim_coords_dat, cube)

# Setup colorbar
fontsize = self.plots[plot_type]['fontsize']
colorbar = fig.colorbar(plot_profile, ax=axes,
Expand Down Expand Up @@ -808,7 +891,7 @@ def _check_cube_dimensions(cube, plot_type):
for dims in expected_dimensions:
cube_dims = [cube.coords(dim, dim_coords=True) for dim in dims]
if all(cube_dims) and cube.ndim == len(dims):
return
return dims
expected_dims_str = ' or '.join(
[str(dims) for dims in expected_dimensions]
)
Expand Down Expand Up @@ -968,6 +1051,12 @@ def create_map_plot(self, datasets, short_name):
)
ancestors.append(ref_dataset['filename'])

# If statistics are shown add a brief description to the caption
if self.plots[plot_type]['show_stats']:
caption += (
" The number in the top left corner corresponds to the "
"spatial mean (weighted by grid cell areas).")

# Save plot
plt.savefig(plot_path, **self.cfg['savefig_kwargs'])
logger.info("Wrote %s", plot_path)
Expand Down Expand Up @@ -1028,6 +1117,12 @@ def create_profile_plot(self, datasets, short_name):
)
ancestors.append(ref_dataset['filename'])

# If statistics are shown add a brief description to the caption
if self.plots[plot_type]['show_stats']:
caption += (
" The number in the top left corner corresponds to the "
"spatial mean (weighted by grid cell areas).")

# Save plot
plt.savefig(plot_path, **self.cfg['savefig_kwargs'])
logger.info("Wrote %s", plot_path)
Expand Down

0 comments on commit 752c3f6

Please sign in to comment.