Skip to content

Commit

Permalink
add is_nodata, is_nan (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
ValentinaHutter authored Nov 18, 2024
1 parent 84c24dd commit 39f8630
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
12 changes: 12 additions & 0 deletions openeo_processes_dask/process_implementations/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
__all__ = [
"is_infinite",
"is_valid",
"is_nan",
"is_nodata",
"eq",
"neq",
"gt",
Expand All @@ -35,6 +37,16 @@ def is_valid(x: ArrayLike):
return np.logical_and(notnull(x), finite)


def is_nodata(x: ArrayLike):
return x is None


def is_nan(x: ArrayLike):
if is_nodata(x):
return is_nodata(x)
return np.isnan(x)


def eq(
x: ArrayLike,
y: ArrayLike,
Expand Down
35 changes: 28 additions & 7 deletions tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,7 @@
from openeo_pg_parser_networkx.pg_schema import ParameterReference

from openeo_processes_dask.process_implementations import merge_cubes
from openeo_processes_dask.process_implementations.comparison import (
between,
eq,
is_infinite,
is_valid,
neq,
)
from openeo_processes_dask.process_implementations.comparison import *
from openeo_processes_dask.process_implementations.cubes.apply import apply
from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension
from tests.general_checks import assert_numpy_equals_dask_numpy, general_output_checks
Expand Down Expand Up @@ -73,6 +67,33 @@ def test_is_inf(value, expected, is_dask):
assert hasattr(output, "dask")


@pytest.mark.parametrize(
"value,expected",
[
(1, False),
(np.nan, True),
],
)
def test_is_nan(value, expected):
value = np.asarray(value)

is_dask = da.from_array(value)

output = is_nan(value)
np.testing.assert_array_equal(output, expected)

assert hasattr(is_nan(is_dask), "dask")


@pytest.mark.parametrize(
"value,expected",
[(1, False), ("Test", False), (None, True), ([np.nan, np.nan], False)],
)
def test_is_nodata(value, expected):
output = is_nodata(value)
np.testing.assert_array_equal(output, expected)


@pytest.mark.parametrize(
"x, y, delta, case_sensitive",
[
Expand Down

0 comments on commit 39f8630

Please sign in to comment.