Skip to content

Commit

Permalink
2023.10.3 (#173)
Browse files Browse the repository at this point in the history
* update prediction

* update tests for code cov

* handle one band
  • Loading branch information
ValentinaHutter authored Oct 11, 2023
1 parent 6e77c76 commit 2c8f1cc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
12 changes: 10 additions & 2 deletions openeo_processes_dask/process_implementations/ml/curve_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def fit_curve(
raise DimensionNotAvailable(
f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
)
bands_required = False
if "bands" in data.dims:
if len(data["bands"].values) == 1:
bands_required = data["bands"].values[0]

try:
# Try parsing as datetime first
Expand Down Expand Up @@ -81,11 +85,15 @@ def _wrap(*args, **kwargs):
.drop_dims(["cov_i", "cov_j"])
.to_array()
.squeeze()
.transpose(*expected_dims_after)
)

fit_result.attrs = data.attrs
fit_result = fit_result.rio.write_crs(rechunked_data.rio.crs)
if bands_required and not "bands" in fit_result.dims:
fit_result = fit_result.assign_coords(**{"bands": bands_required})
fit_result = fit_result.expand_dims(dim="bands")

fit_result = fit_result.transpose(*expected_dims_after)

return fit_result

Expand All @@ -99,6 +107,7 @@ def predict_curve(
):
labels_were_datetime = False
dims_before = list(parameters.dims)
initial_labels = labels

try:
# Try parsing as datetime first
Expand All @@ -108,7 +117,6 @@ def predict_curve(

if np.issubdtype(labels.dtype, np.datetime64):
labels_were_datetime = True
initial_labels = labels
timestep = [
(
(np.datetime64(x) - np.datetime64("1970-01-01", "s"))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openeo-processes-dask"
version = "2023.10.2"
version = "2023.10.3"
description = "Python implementations of many OpenEO processes, dask-friendly by default."
authors = ["Lukas Weidenholzer <[email protected]>", "Sean Hoyal <[email protected]>", "Valentina Hutter <[email protected]>"]
maintainers = ["EODC Staff <[email protected]>"]
Expand Down
10 changes: 8 additions & 2 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,14 @@ def fitFunction(x, parameters):
assert len(result.coords["y"]) == len(origin_cube.coords["y"])
assert len(result.coords["param"]) == len(parameters)

origin_cube_B02 = origin_cube.sel(bands=["B02"])
result_B02 = fit_curve(
origin_cube_B02, parameters=parameters, function=_process, dimension="t"
)
assert "bands" in result_B02.dims
assert result_B02["bands"].values == "B02"

labels = dimension_labels(origin_cube, origin_cube.openeo.temporal_dims[0])
labels = [float(l) for l in labels]
predictions = predict_curve(
result,
_process,
Expand All @@ -96,7 +102,7 @@ def fitFunction(x, parameters):
assert "param" not in predictions.dims
assert result.rio.crs == predictions.rio.crs

labels = ["2020-02-02", "2020-03-02", "2020-04-02", "2020-05-02"]
labels = [0, 1, 2, 3]
predictions = predict_curve(
result,
_process,
Expand Down

0 comments on commit 2c8f1cc

Please sign in to comment.