def count(
    data: ArrayLike,
    condition: Optional[Union[Callable, bool]] = None,
    context: Any = None,
    axis=None,
    keepdims=False,
):
    """Count the elements of ``data`` that satisfy ``condition``.

    Parameters
    ----------
    data
        Array whose elements are counted.
    condition
        Selects what is counted:

        * ``None`` (default): the mask returned by ``is_valid(data)`` is
          summed.
        * ``True`` (Python or NumPy boolean): every element is counted,
          whatever its value.
        * a callable: it is called once as ``condition(data)`` and its
          result is summed, so a boolean mask yields the number of ``True``
          entries.
    context
        Not used by this implementation and not forwarded to ``condition``.
    axis
        Axis or axes to count along; forwarded to ``np.nansum``.
    keepdims
        Forwarded to ``np.nansum``.

    Returns
    -------
    The result of ``np.nansum`` over the selected mask (NaN entries in the
    mask contribute zero).

    Raises
    ------
    ValueError
        If ``condition`` is ``False``.
    TypeError
        If ``condition`` is neither ``None``, a boolean nor a callable.
    """
    if condition is None:
        valid_mask = is_valid(data)
        return np.nansum(valid_mask, axis=axis, keepdims=keepdims)

    # Accept NumPy booleans as well: an identity check against ``True`` would
    # let ``np.True_`` fall through every branch.
    if isinstance(condition, (bool, np.bool_)):
        if not condition:
            # Without this guard the function used to return ``None``.
            raise ValueError(
                "`condition=False` is not a valid value for `count`; "
                "use None, True or a callable."
            )
        return np.nansum(np.ones_like(data), axis=axis, keepdims=keepdims)

    if not callable(condition):
        raise TypeError(
            "`condition` must be None, True or a callable, "
            f"got {type(condition).__name__}."
        )

    # Named ``matches`` so the local does not shadow this function's name.
    matches = condition(data)
    return np.nansum(matches, axis=axis, keepdims=keepdims)
@pytest.mark.parametrize("size", [(3, 3, 2, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_count(temporal_interval, bounding_box, random_raster_data, process_registry):
    """Reduce the ``bands`` dimension with ``count`` and expect 4 everywhere.

    The cube is built with four band labels. The reducer is run twice: once
    with the default condition and once with ``condition=True``. Both runs
    must drop ``bands`` and produce a cube whose every value equals 4.
    NOTE(review): the default-condition run only equals 4 if the random
    fixture data contains no invalid values - confirm against the fixture.
    """
    band_labels = ["B02", "B03", "B04", "B08"]
    input_cube = create_fake_rastercube(
        data=random_raster_data,
        spatial_extent=bounding_box,
        temporal_extent=temporal_interval,
        bands=band_labels,
        backend="dask",
    )
    count_impl = process_registry["count"].implementation

    # First the default condition, then the "count everything" variant.
    for extra_kwargs in ({}, {"condition": True}):
        reducer = partial(
            count_impl,
            data=ParameterReference(from_parameter="data"),
            **extra_kwargs,
        )
        output_cube = reduce_dimension(
            data=input_cube, reducer=reducer, dimension="bands"
        )
        general_output_checks(
            input_cube=input_cube,
            output_cube=output_cube,
            verify_attrs=False,
            verify_crs=True,
        )
        assert output_cube.dims == ("x", "y", "t")
        expected = xr.zeros_like(output_cube) + 4
        xr.testing.assert_equal(output_cube, expected)