Skip to content

Commit

Permalink
fix: histogram only on non-nan values (#15)
Browse files Browse the repository at this point in the history
* fix: histogram only in non-nan values
* fix: enforce the same binning for histograms

reviewers: @anaprietonem, @gmertes
  • Loading branch information
sahahner authored Aug 19, 2024
1 parent c23a8c8 commit 0436daf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.1.0...HEAD)

### Added

#### Functionality

- Enable the callback for plotting a histogram for variables containing NaNs
- Enforce same binning for histograms comparing true data to predicted data

## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/compare/x.x.x...0.1.0) - 2024-08-16

### Added
Expand Down
26 changes: 18 additions & 8 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def plot_power_spectrum(
ax[plot_idx].loglog(
np.arange(1, amplitude_t.shape[0]),
amplitude_t[1 : (amplitude_t.shape[0])],
label="Truth (ERA5)",
label="Truth (data)",
)
ax[plot_idx].loglog(
np.arange(1, amplitude_p.shape[0]),
Expand Down Expand Up @@ -279,24 +279,34 @@ def plot_histogram(
for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
yt = y_true[..., variable_idx].squeeze()
yp = y_pred[..., variable_idx].squeeze()
# postprocessed outputs so we need to handle possible NaNs

# Calculate the histogram
# Calculate the histogram and handle NaNs
if output_only:
# histogram of true increment and predicted increment
xt = x[..., variable_idx].squeeze() * int(output_only)
hist_yt, bins_yt = np.histogram((yt - xt), bins=100)
hist_yp, bins_yp = np.histogram((yp - xt), bins=100)
yt_xt = yt - xt
yp_xt = yp - xt
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max])
else:
hist_yt, bins_yt = np.histogram(yt, bins=100)
hist_yp, bins_yp = np.histogram(yp, bins=100)
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max])

# Visualization trick for tp
if variable_name in {"tp", "cp"}:
# in-place multiplication does not work here because variables are different numpy types
hist_yt = hist_yt * bins_yt[:-1]
hist_yp = hist_yp * bins_yp[:-1]
# Plot the modified histogram
ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (ERA5)")
ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="Anemoi")
ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (data)")
ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="Predicted")

ax[plot_idx].set_title(variable_name)
ax[plot_idx].set_xlabel(variable_name)
Expand Down

0 comments on commit 0436daf

Please sign in to comment.