Skip to content

Commit

Permalink
Add test that checks the integral (w.r.t. threshold) of DebiasedEnsem…
Browse files Browse the repository at this point in the history
…bleBrierScore

= CRPS (on average).

PiperOrigin-RevId: 719467634
  • Loading branch information
langmore authored and Weatherbench2 authors committed Jan 25, 2025
1 parent 2f849ab commit 3af79e3
Showing 1 changed file with 94 additions and 2 deletions.
96 changes: 94 additions & 2 deletions weatherbench2/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@


def get_random_truth_and_forecast(
variables=('geopotential',), ensemble_size=None, seed=802701, **data_kwargs
variables=('geopotential',),
ensemble_size=None,
seed=802701,
lead_start='0 day',
lead_stop='10 day',
**data_kwargs,
):
"""Makes the tuple (truth, forecast) from kwargs."""
data_kwargs_to_use = dict(
Expand All @@ -43,7 +48,10 @@ def get_random_truth_and_forecast(
)
forecast = utils.random_like(
schema.mock_forecast_data(
ensemble_size=ensemble_size, **data_kwargs_to_use
ensemble_size=ensemble_size,
lead_start=lead_start,
lead_stop=lead_stop,
**data_kwargs_to_use,
),
seed=seed + 1,
)
Expand Down Expand Up @@ -1147,6 +1155,90 @@ def test_versus_large_ensemble_and_ensure_skipna_works(self):
atol=4 * stderr,
)

def test_integral_of_brier_score_is_crps(self):
  """Checks that the threshold-integral of debiased Brier score matches CRPS.

  DebiasedEnsembleBrierScore is evaluated at a dense set of Gaussian quantile
  thresholds, integrated numerically with respect to the threshold value
  (separately for each level), and compared against metrics.CRPS computed on
  the same forecast/truth pair.
  """
  # Few samples are needed: the finite sample Debiased BS integral over all
  # thresholds is exactly equal to the finite sample CRPS.
  truth, forecast = get_random_truth_and_forecast(
      ensemble_size=2,
      spatial_resolution_in_degrees=60,
      time_start='2019-01-01',
      time_stop='2019-01-04',
      time_resolution='12 hours',
      lead_start='0 day',
      lead_stop='0 day',
      levels=[500, 700, 850],
  )

  # Perturb the forecast so that it (i) has a different mean/variance than
  # truth, and (ii) depends on level.
  level_offset = xr.DataArray(
      [-1, 0, 1], dims=['level'], coords={'level': forecast.level.data}
  )
  forecast = forecast + np.abs(forecast) ** 0.2 + level_offset

  # The climatology has the stats of Normal(0, 1), the same distribution as
  # "truth", which keeps the climatological quantiles reasonable.
  clim_template = truth.isel(time=0, drop=True).expand_dims(dayofyear=366)
  climatology = xr.merge([
      xr.zeros_like(clim_template),
      xr.ones_like(
          clim_template.rename({'geopotential': 'geopotential_std'})
      ),
  ])

  # Interior quantiles only: the endpoints 0 and 1 are dropped.
  n_quantiles = 200
  quantiles = np.linspace(0, 1, num=n_quantiles + 2)[1:-1]
  threshold_objects = [
      thresholds.GaussianQuantileThreshold(climatology=climatology, quantile=q)
      for q in quantiles
  ]
  brier_metric = metrics.DebiasedEnsembleBrierScore(threshold_objects)
  bs = brier_metric.compute(forecast, truth)['geopotential']

  # To integrate BS with respect to the threshold, first collect the threshold
  # value that corresponds to each quantile.
  per_quantile_thresholds = []
  for quantile, threshold_object in zip(quantiles, threshold_objects):
    threshold_values = threshold_object.compute(truth)['geopotential']
    # Integration is simple only if the threshold depends on nothing but
    # level; verify it is (numerically) constant over the other dimensions.
    spread = threshold_values.std(['time', 'longitude', 'latitude'])
    np.testing.assert_array_less(spread, 1e-4)
    representative = threshold_values.isel(
        time=0, longitude=0, latitude=0, drop=True
    )
    per_quantile_thresholds.append(
        representative.expand_dims(quantile=[quantile])
    )
  threshold_by_quantile = xr.concat(per_quantile_thresholds, dim='quantile')

  # Then integrate BS against those threshold values, one level at a time.
  bs = bs.assign_coords(threshold=threshold_by_quantile)
  bs_integral = xr.concat(
      [bs.sel(level=level).integrate('threshold') for level in bs.level],
      dim='level',
  )

  crps = metrics.CRPS().compute(forecast, truth)['geopotential']

  # Tolerance is due to integration error only.
  # Integration error is going to be tiny, due to using 200 points to
  # interpolate a function that we know is bounded to ≈ [-5, 5].
  # The integrand involves indicator functions, so it is not smooth. So we
  # only get O(1 / n_quantiles) error bounds despite integration being
  # Trapezoidal.
  xr.testing.assert_allclose(bs_integral, crps, rtol=10 / n_quantiles)


class EnsembleIgnoranceScoreTest(parameterized.TestCase):

Expand Down

0 comments on commit 3af79e3

Please sign in to comment.