Skip to content

Commit

Permalink
Unify dtype handling of preprocessors (#2393)
Browse files Browse the repository at this point in the history
  • Loading branch information
schlunma authored Apr 25, 2024
1 parent 1e49077 commit f5ac7fb
Show file tree
Hide file tree
Showing 16 changed files with 174 additions and 45 deletions.
5 changes: 5 additions & 0 deletions doc/api/esmvalcore.preprocessor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
Preprocessor functions
======================

All preprocessor functions are designed to preserve floating point data types.
For example, input data of type ``float32`` will give output of ``float32``.
However, other data types may change, e.g., data of type ``int`` may give
output of type ``float64``.

.. autodata:: esmvalcore.preprocessor.DEFAULT_ORDER
.. automodule:: esmvalcore.preprocessor
10 changes: 8 additions & 2 deletions esmvalcore/_recipe/check.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Module with functions to check a recipe."""
from __future__ import annotations

import inspect
import logging
import os
import subprocess
from inspect import getfullargspec
from pprint import pformat
from shutil import which
from typing import Any, Iterable
Expand Down Expand Up @@ -470,7 +470,13 @@ def _check_regular_stat(step, step_settings):
return

# Ignore other preprocessor arguments, e.g., 'hours' for hourly_statistics
other_args = getfullargspec(preproc_func).args[1:]
other_args = [
n for (n, p) in inspect.signature(preproc_func).parameters.items() if
p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
][1:]
operator_kwargs = {
k: v for (k, v) in step_settings.items() if k not in other_args
}
Expand Down
41 changes: 31 additions & 10 deletions esmvalcore/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@
def _get_itype(step):
"""Get the input type of a preprocessor function."""
function = globals()[step]
itype = inspect.getfullargspec(function).args[0]
itype = list(inspect.signature(function).parameters)[0]
return itype


Expand All @@ -238,31 +238,52 @@ def check_preprocessor_settings(settings):
f"{', '.join(DEFAULT_ORDER)}"
)

function = function = globals()[step]
argspec = inspect.getfullargspec(function)
args = argspec.args[1:]
if not (argspec.varargs or argspec.varkw):
# Check for invalid arguments
function = globals()[step]

# Note: below, we do not use inspect.getfullargspec since this does not
# work with decorated functions. On the other hand, inspect.signature
# behaves correctly with properly decorated functions (those that use
# functools.wraps).
signature = inspect.signature(function)
args = [
n for (n, p) in signature.parameters.items() if
p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
][1:]

# Check for invalid arguments (only possible if no *args or **kwargs
# allowed)
var_kinds = [p.kind for p in signature.parameters.values()]
check_args = not any([
inspect.Parameter.VAR_POSITIONAL in var_kinds,
inspect.Parameter.VAR_KEYWORD in var_kinds,
])
if check_args:
invalid_args = set(settings[step]) - set(args)
if invalid_args:
raise ValueError(
f"Invalid argument(s): {', '.join(invalid_args)} "
f"Invalid argument(s) [{', '.join(invalid_args)}] "
f"encountered for preprocessor function {step}. \n"
f"Valid arguments are: [{', '.join(args)}]"
)

# Check for missing arguments
defaults = argspec.defaults
end = None if defaults is None else -len(defaults)
defaults = [
p.default for p in signature.parameters.values()
if p.default is not inspect.Parameter.empty
]
end = None if not defaults else -len(defaults)
missing_args = set(args[:end]) - set(settings[step])
if missing_args:
raise ValueError(
f"Missing required argument(s) {missing_args} for "
f"preprocessor function {step}"
)

# Final sanity check in case the above fails to catch a mistake
try:
signature = inspect.Signature.from_callable(function)
signature.bind(None, **settings[step])
except TypeError:
logger.error(
Expand Down
17 changes: 4 additions & 13 deletions esmvalcore/preprocessor/_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_iris_aggregator,
get_normalized_cube,
guess_bounds,
preserve_float_dtype,
update_weights_kwargs,
)
from esmvalcore.preprocessor._supplementary_vars import (
Expand Down Expand Up @@ -190,6 +191,7 @@ def _extract_irregular_region(cube, start_longitude, end_longitude,
return cube


@preserve_float_dtype
def zonal_statistics(
cube: Cube,
operator: str,
Expand Down Expand Up @@ -236,10 +238,10 @@ def zonal_statistics(
result = cube.collapsed('longitude', agg, **agg_kwargs)
if normalize is not None:
result = get_normalized_cube(cube, result, normalize)
result.data = result.core_data().astype(np.float32, casting='same_kind')
return result


@preserve_float_dtype
def meridional_statistics(
cube: Cube,
operator: str,
Expand Down Expand Up @@ -285,7 +287,6 @@ def meridional_statistics(
result = cube.collapsed('latitude', agg, **agg_kwargs)
if normalize is not None:
result = get_normalized_cube(cube, result, normalize)
result.data = result.core_data().astype(np.float32, casting='same_kind')
return result


Expand Down Expand Up @@ -404,6 +405,7 @@ def _try_adding_calculated_cell_area(cube: Cube) -> None:
variables=['areacella', 'areacello'],
required='prefer_at_least_one',
)
@preserve_float_dtype
def area_statistics(
cube: Cube,
operator: str,
Expand Down Expand Up @@ -451,7 +453,6 @@ def area_statistics(
`cell_area` is not available.
"""
original_dtype = cube.dtype
has_cell_measure = bool(cube.cell_measures('cell_area'))

# Get aggregator and correct kwargs (incl. weights)
Expand All @@ -464,16 +465,6 @@ def area_statistics(
if normalize is not None:
result = get_normalized_cube(cube, result, normalize)

# Make sure to preserve dtype
new_dtype = result.dtype
if original_dtype != new_dtype:
logger.debug(
"area_statistics changed dtype from %s to %s, changing back",
original_dtype,
new_dtype,
)
result.data = result.core_data().astype(original_dtype)

# Make sure input cube has not been modified
if not has_cell_measure and cube.cell_measures('cell_area'):
cube.remove_cell_measure('cell_area')
Expand Down
2 changes: 1 addition & 1 deletion esmvalcore/preprocessor/_derive/sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def calculate(cubes):
"""
mrsos_cube = cubes.extract_cube(NameConstraint(var_name='mrsos'))

depth = mrsos_cube.coord('depth').core_bounds().astype(np.float32)
depth = mrsos_cube.coord('depth').core_bounds().astype(np.float64)
layer_thickness = depth[..., 1] - depth[..., 0]

sm_cube = mrsos_cube / layer_thickness / 998.2
Expand Down
4 changes: 4 additions & 0 deletions esmvalcore/preprocessor/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from esmvalcore.exceptions import ESMValCoreDeprecationWarning
from esmvalcore.iris_helpers import has_irregular_grid, has_unstructured_grid
from esmvalcore.preprocessor._other import get_array_module
from esmvalcore.preprocessor._shared import preserve_float_dtype
from esmvalcore.preprocessor._supplementary_vars import (
add_ancillary_variable,
add_cell_measure,
Expand Down Expand Up @@ -707,6 +708,7 @@ def _get_name_and_shape_key(
return (name, *shapes)


@preserve_float_dtype
def regrid(
cube: Cube,
target_grid: Cube | Dataset | Path | str | dict,
Expand Down Expand Up @@ -1205,6 +1207,7 @@ def _rechunk_aux_factory_dependencies(
return cube


@preserve_float_dtype
def extract_levels(
cube: iris.cube.Cube,
levels: np.typing.ArrayLike | da.Array,
Expand Down Expand Up @@ -1395,6 +1398,7 @@ def get_reference_levels(dataset):
return coord.points.tolist()


@preserve_float_dtype
def extract_coordinate_points(cube, definition, scheme):
"""Extract points from any coordinate with interpolation.
Expand Down
3 changes: 2 additions & 1 deletion esmvalcore/preprocessor/_rolling_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from iris.cube import Cube

from ._shared import get_iris_aggregator
from ._shared import get_iris_aggregator, preserve_float_dtype

logger = logging.getLogger(__name__)


@preserve_float_dtype
def rolling_window_statistics(
cube: Cube,
coordinate: str,
Expand Down
28 changes: 28 additions & 0 deletions esmvalcore/preprocessor/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
Utility functions that can be used for multiple preprocessor steps
"""
from __future__ import annotations

import logging
import re
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any, Literal, Optional

import dask.array as da
Expand All @@ -16,6 +19,7 @@
from iris.cube import Cube

from esmvalcore.exceptions import ESMValCoreDeprecationWarning
from esmvalcore.typing import DataType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -238,3 +242,27 @@ def get_normalized_cube(
normalized_cube.units = new_units

return normalized_cube


def preserve_float_dtype(func: Callable) -> Callable:
"""Preserve object's float dtype (all other dtypes are allowed to change).
This can be used as a decorator for preprocessor functions to ensure that
floating dtypes are preserved. For example, input of type float32 will
always give output of type float32, but input of type int will be allowed
to give output with any type.
"""

@wraps(func)
def wrapper(data: DataType, *args: Any, **kwargs: Any) -> DataType:
dtype = data.dtype
result = func(data, *args, **kwargs)
if np.issubdtype(dtype, np.floating) and result.dtype != dtype:
if isinstance(result, Cube):
result.data = result.core_data().astype(dtype)
else:
result = result.astype(dtype)
return result

return wrapper
20 changes: 11 additions & 9 deletions esmvalcore/preprocessor/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from esmvalcore.iris_helpers import date2num, rechunk_cube
from esmvalcore.preprocessor._shared import (
get_iris_aggregator,
preserve_float_dtype,
update_weights_kwargs,
)

Expand Down Expand Up @@ -446,6 +447,7 @@ def _aggregate_time_fx(result_cube, source_cube):
ancillary_dims)


@preserve_float_dtype
def hourly_statistics(
cube: Cube,
hours: int,
Expand Down Expand Up @@ -501,6 +503,7 @@ def hourly_statistics(
return result


@preserve_float_dtype
def daily_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -541,6 +544,7 @@ def daily_statistics(
return result


@preserve_float_dtype
def monthly_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -579,6 +583,7 @@ def monthly_statistics(
return result


@preserve_float_dtype
def seasonal_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -673,6 +678,7 @@ def spans_full_season(cube: Cube) -> list[bool]:
return result


@preserve_float_dtype
def annual_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -714,6 +720,7 @@ def annual_statistics(
return result


@preserve_float_dtype
def decadal_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -762,6 +769,7 @@ def get_decade(coord, value):
return result


@preserve_float_dtype
def climate_statistics(
cube: Cube,
operator: str = 'mean',
Expand Down Expand Up @@ -804,7 +812,6 @@ def climate_statistics(
iris.cube.Cube
Climate statistics cube.
"""
original_dtype = cube.dtype
period = period.lower()

# Use Cube.collapsed when full period is requested
Expand Down Expand Up @@ -846,14 +853,6 @@ def climate_statistics(
clim_cube.slices_over(clim_coord.name())).merge_cube()
cube.remove_coord(clim_coord)

# Make sure that original dtype is preserved
new_dtype = clim_cube.dtype
if original_dtype != new_dtype:
logger.debug(
"climate_statistics changed dtype from "
"%s to %s, changing back", original_dtype, new_dtype)
clim_cube.data = clim_cube.core_data().astype(original_dtype)

return clim_cube


Expand All @@ -867,6 +866,7 @@ def _add_time_weights_coord(cube):
cube.add_aux_coord(time_weights_coord, cube.coord_dims('time'))


@preserve_float_dtype
def anomalies(
cube: Cube,
period: str,
Expand Down Expand Up @@ -1104,6 +1104,7 @@ def low_pass_weights(window, cutoff):
return weights[1:-1]


@preserve_float_dtype
def timeseries_filter(
cube: Cube,
window: int,
Expand Down Expand Up @@ -1658,6 +1659,7 @@ def _check_cube_coords(cube):
)


@preserve_float_dtype
def local_solar_time(cube: Cube) -> Cube:
"""Convert UTC time coordinate to local solar time (LST).
Expand Down
Loading

0 comments on commit f5ac7fb

Please sign in to comment.