Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous lazy preprocessor improvements #2520

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions esmvalcore/preprocessor/_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from iris.exceptions import CoordinateNotFoundError

from esmvalcore.preprocessor._shared import (
apply_mask,
get_dims_along_axes,
get_iris_aggregator,
get_normalized_cube,
preserve_float_dtype,
Expand Down Expand Up @@ -188,8 +190,8 @@ def _extract_irregular_region(
cube = cube[..., i_slice, j_slice]
selection = selection[i_slice, j_slice]
# Mask remaining coordinates outside region
mask = da.broadcast_to(~selection, cube.shape)
cube.data = da.ma.masked_where(mask, cube.core_data())
horizontal_dims = get_dims_along_axes(cube, ["X", "Y"])
cube.data = apply_mask(~selection, cube.core_data(), horizontal_dims)
return cube


Expand Down Expand Up @@ -857,31 +859,45 @@ def _mask_cube(cube: Cube, masks: dict[str, np.ndarray]) -> Cube:
_cube.add_aux_coord(
AuxCoord(id_, units="no_unit", long_name="shape_id")
)
mask = da.broadcast_to(mask, _cube.shape)
_cube.data = da.ma.masked_where(~mask, _cube.core_data())
horizontal_dims = get_dims_along_axes(cube, axes=["X", "Y"])
_cube.data = apply_mask(~mask, _cube.core_data(), horizontal_dims)
cubelist.append(_cube)
result = fix_coordinate_ordering(cubelist.merge_cube())
if cube.cell_measures():
for measure in cube.cell_measures():
# Cell measures that are time-dependent, with 4 dimension and
# an original shape of (time, depth, lat, lon), need to be
# broadcasted to the cube with 5 dimensions and shape
# (time, shape_id, depth, lat, lon)
if measure.ndim > 3 and result.ndim > 4:
data = measure.core_data()
data = da.expand_dims(data, axis=(1,))
data = da.broadcast_to(data, result.shape)
measure = iris.coords.CellMeasure(
for measure in cube.cell_measures():
# Cell measures that are time-dependent, with 4 dimension and
# an original shape of (time, depth, lat, lon), need to be
# broadcasted to the cube with 5 dimensions and shape
# (time, shape_id, depth, lat, lon)
if measure.ndim > 3 and result.ndim > 4:
data = measure.core_data()
if result.has_lazy_data():
# Make the cell measure lazy if the result is lazy.
cube_chunks = cube.lazy_data().chunks
chunk_dims = cube.cell_measure_dims(measure)
data = da.asarray(
data,
standard_name=measure.standard_name,
long_name=measure.long_name,
units=measure.units,
measure=measure.measure,
var_name=measure.var_name,
attributes=measure.attributes,
chunks=tuple(cube_chunks[i] for i in chunk_dims),
)
add_cell_measure(result, measure, measure.measure)
if cube.ancillary_variables():
for ancillary_variable in cube.ancillary_variables():
add_ancillary_variable(result, ancillary_variable)
chunks = result.lazy_data().chunks
else:
chunks = None
dim_map = get_dims_along_axes(result, ["T", "Z", "Y", "X"])
data = iris.util.broadcast_to_shape(
data,
result.shape,
dim_map=dim_map,
chunks=chunks,
)
measure = iris.coords.CellMeasure(
data,
standard_name=measure.standard_name,
long_name=measure.long_name,
units=measure.units,
measure=measure.measure,
var_name=measure.var_name,
attributes=measure.attributes,
)
add_cell_measure(result, measure, measure.measure)
for ancillary_variable in cube.ancillary_variables():
add_ancillary_variable(result, ancillary_variable)
return result
2 changes: 2 additions & 0 deletions esmvalcore/preprocessor/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from esmvalcore.cmor.check import CheckLevels
from esmvalcore.esgf.facets import FACETS
from esmvalcore.iris_helpers import merge_cube_attributes
from esmvalcore.preprocessor._shared import _rechunk_aux_factory_dependencies

from .._task import write_ncl_settings

Expand Down Expand Up @@ -392,6 +393,7 @@ def concatenate(cubes, check_level=CheckLevels.DEFAULT):
cubes = _sort_cubes_by_time(cubes)
_fix_calendars(cubes)
cubes = _check_time_overlaps(cubes)
cubes = [_rechunk_aux_factory_dependencies(cube) for cube in cubes]
result = _concatenate_cubes(cubes, check_level=check_level)

if len(result) == 1:
Expand Down
32 changes: 5 additions & 27 deletions esmvalcore/preprocessor/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import logging
import os
from collections.abc import Iterable
from typing import Literal, Optional
from typing import Literal

import cartopy.io.shapereader as shpreader
import dask.array as da
Expand All @@ -22,7 +21,7 @@
from iris.cube import Cube
from iris.util import rolling_window

from esmvalcore.preprocessor._shared import get_array_module
from esmvalcore.preprocessor._shared import apply_mask

from ._supplementary_vars import register_supplementaries

Expand Down Expand Up @@ -61,24 +60,6 @@ def _get_fx_mask(
return inmask


def _apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Optional[Iterable[int]] = None,
) -> np.ndarray | da.Array:
"""Apply a (broadcasted) mask on an array."""
npx = get_array_module(mask, array)
if dim_map is not None:
if isinstance(array, da.Array):
chunks = array.chunks
else:
chunks = None
mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map, chunks=chunks
)
return npx.ma.masked_where(mask, array)


@register_supplementaries(
variables=["sftlf", "sftof"],
required="prefer_at_least_one",
Expand Down Expand Up @@ -145,7 +126,7 @@ def mask_landsea(cube: Cube, mask_out: Literal["land", "sea"]) -> Cube:
landsea_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landsea_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
Expand Down Expand Up @@ -212,7 +193,7 @@ def mask_landseaice(cube: Cube, mask_out: Literal["landsea", "ice"]) -> Cube:
landseaice_mask = _get_fx_mask(
ancillary_var.core_data(), mask_out, ancillary_var.var_name
)
cube.data = _apply_mask(
cube.data = apply_mask(
landseaice_mask,
cube.core_data(),
cube.ancillary_variable_dims(ancillary_var),
Expand Down Expand Up @@ -350,10 +331,7 @@ def _mask_with_shp(cube, shapefilename, region_indices=None):
else:
mask |= shp_vect.contains(region, x_p_180, y_p_90)

if cube.has_lazy_data():
mask = da.array(mask)

cube.data = _apply_mask(
cube.data = apply_mask(
mask,
cube.core_data(),
cube.coord_dims("latitude") + cube.coord_dims("longitude"),
Expand Down
31 changes: 1 addition & 30 deletions esmvalcore/preprocessor/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from esmvalcore.exceptions import ESMValCoreDeprecationWarning
from esmvalcore.iris_helpers import has_irregular_grid, has_unstructured_grid
from esmvalcore.preprocessor._shared import (
_rechunk_aux_factory_dependencies,
get_array_module,
get_dims_along_axes,
preserve_float_dtype,
Expand Down Expand Up @@ -1174,36 +1175,6 @@ def parse_vertical_scheme(scheme):
return scheme, extrap_scheme


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.

This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
try:
factory = cube.aux_factory(coord_name)
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(chunks + (None,))
cube.replace_coord(coord)
return cube


@preserve_float_dtype
def extract_levels(
cube: iris.cube.Cube,
Expand Down
82 changes: 82 additions & 0 deletions esmvalcore/preprocessor/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,85 @@ def get_dims_along_coords(
"""Get a tuple with the dimensions along one or more coordinates."""
dims = {d for coord in coords for d in _get_dims_along(cube, coord)}
return tuple(sorted(dims))


def apply_mask(
mask: np.ndarray | da.Array,
array: np.ndarray | da.Array,
dim_map: Iterable[int],
) -> np.ma.MaskedArray | da.Array:
"""Apply a (broadcasted) mask on an array.

Parameters
----------
mask:
The mask to apply to array.
array:
The array to mask out.
dim_map : :class:`list`, :class:`tuple` etc
A mapping of the dimensions of *mask* to their corresponding
dimension in *array*. *dim_map* must be the same length as the
number of dimensions in *mask*. Each element of *dim_map*
corresponds to a dimension of *mask* and its value provides
the index in *array* which the dimension of *mask* corresponds
to, so the first element of *dim_map* gives the index of *array*
that corresponds to the first dimension of *mask* etc.

Returns
-------
np.ma.MaskedArray or da.Array:
A copy of the input array with the mask applied.

"""
if isinstance(array, da.Array):
array_chunks = array.chunks
# If the mask is not a Dask array yet, we make it into a Dask array
# before broadcasting to avoid inserting a large array into the Dask
# graph.
mask_chunks = tuple(array_chunks[i] for i in dim_map)
mask = da.asarray(mask, chunks=mask_chunks)
else:
array_chunks = None

mask = iris.util.broadcast_to_shape(
mask, array.shape, dim_map=dim_map, chunks=array_chunks
)

array_module = get_array_module(mask, array)
return array_module.ma.masked_where(mask, array)


def _rechunk_aux_factory_dependencies(
cube: iris.cube.Cube,
coord_name: str | None = None,
) -> iris.cube.Cube:
"""Rechunk coordinate aux factory dependencies.

This ensures that the resulting coordinate has reasonably sized
chunks that are aligned with the cube data for optimal computational
performance.
"""
# Workaround for https://github.com/SciTools/iris/issues/5457
if coord_name is None:
factories = cube.aux_factories
else:
try:
factories = [cube.aux_factory(coord_name)]
except iris.exceptions.CoordinateNotFoundError:
return cube

cube = cube.copy()
cube_chunks = cube.lazy_data().chunks
for factory in factories:
for coord in factory.dependencies.values():
coord_dims = cube.coord_dims(coord)
if coord_dims:
coord = coord.copy()
chunks = tuple(cube_chunks[i] for i in coord_dims)
coord.points = coord.lazy_points().rechunk(chunks)
if coord.has_bounds():
coord.bounds = coord.lazy_bounds().rechunk(
chunks + (None,)
)
cube.replace_coord(coord)
return cube
4 changes: 4 additions & 0 deletions esmvalcore/preprocessor/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,10 @@ def timeseries_filter(
# Apply filter
(agg, agg_kwargs) = get_iris_aggregator(filter_stats, **operator_kwargs)
agg_kwargs["weights"] = wgts
if cube.has_lazy_data():
# Ensure the cube data chunktype is np.MaskedArray so rolling_window
# does not ignore a potential mask.
cube.data = da.ma.masked_array(cube.core_data())
cube = cube.rolling_window("time", agg, len(wgts), **agg_kwargs)

return cube
Expand Down
Loading