Skip to content

Commit

Permalink
Allowed usage of multi_model_statistics on single cubes/products (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
schlunma authored Jan 13, 2023
1 parent fe2e8cf commit 80dc689
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 17 deletions.
15 changes: 11 additions & 4 deletions esmvalcore/preprocessor/_multimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from iris.cube import Cube, CubeList
from iris.exceptions import MergeError
from iris.util import equalise_attributes
from iris.util import equalise_attributes, new_axis

from esmvalcore.iris_helpers import date2num
from esmvalcore.preprocessor import remove_fx_variables
Expand Down Expand Up @@ -302,6 +302,12 @@ def _combine(cubes):

cubes = CubeList(cubes)

# For a single cube, merging returns a scalar CONCAT_DIM, which leads to a
# "Cannot collapse a dimension which does not describe any data" error when
# collapsing. Thus, treat single cubes differently here.
if len(cubes) == 1:
return new_axis(cubes[0], scalar_coord=CONCAT_DIM)

try:
merged_cube = cubes.merge_cube()
except MergeError as exc:
Expand Down Expand Up @@ -411,9 +417,10 @@ def _multicube_statistics(cubes, statistics, span):
Cubes are merged and subsequently collapsed along a new auxiliary
coordinate. Inconsistent attributes will be removed.
"""
if len(cubes) == 1:
raise ValueError('Cannot perform multicube statistics '
'for a single cube.')
if not cubes:
raise ValueError(
"Cannot perform multicube statistics for an empty list of cubes"
)

# Avoid modifying inputs
copied_cubes = [cube.copy() for cube in cubes]
Expand Down
110 changes: 97 additions & 13 deletions tests/unit/preprocessor/_multimodel/test_multimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,19 +669,6 @@ def test_edge_case_sub_daily_data_fail(span):
_ = multi_model_statistics(cubes, span, statistics)


@pytest.mark.parametrize('span', SPAN_OPTIONS)
def test_edge_case_single_cube_fail(span):
"""Test that an error is raised when a single cube is passed."""
cube = generate_cube_from_dates('monthly')
cubes = (cube, )

statistic = 'min'
statistics = (statistic, )

with pytest.raises(ValueError):
_ = multi_model_statistics(cubes, span, statistics)


def test_unify_time_coordinates():
"""Test set common calendar."""
cube1 = generate_cube_from_dates('monthly',
Expand Down Expand Up @@ -1045,3 +1032,100 @@ def test_arbitrary_dims_0d(cubes_with_arbitrary_dimensions):
stat_cube = stat_cubes['sum']
assert stat_cube.shape == ()
assert_array_allclose(stat_cube.data, np.ma.array(0.0))


def test_empty_input_multi_model_statistics():
"""Check that ``multi_model_statistics`` fails with empty input."""
msg = "Cannot perform multicube statistics for an empty list of cubes"
with pytest.raises(ValueError, match=msg):
mm.multi_model_statistics([], span='full', statistics=['mean'])


def test_empty_input_ensemble_statistics():
"""Check that ``ensemble_statistics`` fails with empty input."""
msg = "Cannot perform multicube statistics for an empty list of cubes"
with pytest.raises(ValueError, match=msg):
mm.ensemble_statistics(
[], span='full', statistics=['mean'], output_products=[]
)


STATS = ['mean', 'median', 'min', 'max', 'p42.314', 'std_dev']


@pytest.mark.parametrize('stat', STATS)
@pytest.mark.parametrize(
'products',
[
CubeList([generate_cube_from_dates('monthly')]),
set([PreprocessorFile(generate_cube_from_dates('monthly'))]),
],
)
def test_single_input_multi_model_statistics(products, stat):
"""Check that ``multi_model_statistics`` works with a single cube."""
output = PreprocessorFile()
output_products = {'': {stat: output}}
kwargs = {
'statistics': [stat],
'span': 'full',
'output_products': output_products,
'keep_input_datasets': False,
}

results = mm.multi_model_statistics(products, **kwargs)

assert len(results) == 1

if isinstance(results, dict): # for cube as input
cube = results[stat]
else: # for PreprocessorFile as input
result = next(iter(results))
assert len(result.cubes) == 1
cube = result.cubes[0]

if stat == 'std_dev':
assert_array_allclose(
cube.data, np.ma.masked_invalid([np.nan, np.nan, np.nan])
)
else:
assert_array_allclose(cube.data, np.ma.array([1.0, 1.0, 1.0]))


@pytest.mark.parametrize('stat', STATS)
@pytest.mark.parametrize(
'products',
[
CubeList([generate_cube_from_dates('monthly')]),
{PreprocessorFile(generate_cube_from_dates('monthly'))},
],
)
def test_single_input_ensemble_statistics(products, stat):
"""Check that ``ensemble_statistics`` works with a single cube."""
cube = generate_cube_from_dates('monthly')
attributes = {
'project': 'project',
'dataset': 'dataset',
'exp': 'exp',
'ensemble': '1',
}
products = {PreprocessorFile(cube, attributes=attributes)}
output = PreprocessorFile()
output_products = {'project_dataset_exp': {stat: output}}
kwargs = {
'statistics': [stat],
'output_products': output_products,
}

results = mm.ensemble_statistics(products, **kwargs)

assert len(results) == 1
result = next(iter(results))
assert len(result.cubes) == 1
cube = result.cubes[0]

if stat == 'std_dev':
assert_array_allclose(
cube.data, np.ma.masked_invalid([np.nan, np.nan, np.nan])
)
else:
assert_array_allclose(cube.data, np.ma.array([1.0, 1.0, 1.0]))

0 comments on commit 80dc689

Please sign in to comment.