diff --git a/openeogeotrellis/geopysparkdatacube.py b/openeogeotrellis/geopysparkdatacube.py
index 9a6ff1f02..6a74a26cf 100644
--- a/openeogeotrellis/geopysparkdatacube.py
+++ b/openeogeotrellis/geopysparkdatacube.py
@@ -30,7 +30,7 @@
 from openeo.udf import UdfData, run_udf_code
 from openeo.udf.xarraydatacube import XarrayDataCube, XarrayIO
 from openeo.util import dict_no_none, str_truncate
-from openeo_driver.datacube import DriverDataCube
+from openeo_driver.datacube import DriverDataCube, DriverVectorCube
 from openeo_driver.datastructs import ResolutionMergeArgs
 from openeo_driver.datastructs import SarBackscatterArgs
 from openeo_driver.delayed_vector import DelayedVector
@@ -1246,14 +1246,14 @@ def raster_to_vector(self):
     def get_max_level(self):
         return self.pyramid.levels[self.pyramid.max_zoom]
 
-    def aggregate_spatial(self, geometries: Union[str, BaseGeometry], reducer,
+    def aggregate_spatial(self, geometries: Union[str, BaseGeometry, DriverVectorCube], reducer,
                           target_dimension: str = "result") -> Union[AggregatePolygonResult,
                                                                      AggregateSpatialVectorCube]:
         if isinstance(reducer, dict):
             if len(reducer) == 1:
                 single_process = next(iter(reducer.values())).get('process_id')
-                return self.zonal_statistics(geometries,single_process)
+                return self.zonal_statistics(geometries, single_process)
             else:
                 visitor = GeotrellisTileProcessGraphVisitor(_builder=self._get_jvm().org.openeo.geotrellis.aggregate_polygon.SparkAggregateScriptBuilder()).accept_process_graph(reducer)
                 return self.zonal_statistics(geometries, visitor.builder)
@@ -1263,7 +1263,7 @@ def aggregate_spatial(self, geometries: Union[str, BaseGeometry], reducer,
             code="ReducerUnsupported", status_code=400
         )
 
-    def zonal_statistics(self, regions: Union[str, BaseGeometry], func) -> Union[AggregatePolygonResult,
+    def zonal_statistics(self, regions: Union[str, BaseGeometry, DriverVectorCube], func) -> Union[AggregatePolygonResult,
                                                                                  AggregateSpatialVectorCube]:
         # TODO: rename to aggregate_spatial?
         # TODO eliminate code duplication
@@ -1287,7 +1287,7 @@ def csv_dir() -> str:
 
         if isinstance(regions, (Polygon, MultiPolygon)):
             regions = GeometryCollection([regions])
 
-        polygons = (None if isinstance(regions, Point) or
+        projected_polygons = (None if isinstance(regions, Point) or
                     (isinstance(regions, GeometryCollection) and any(isinstance(geom, Point) for geom in regions.geoms))
                     else to_projected_polygons(self._get_jvm(), regions))
@@ -1322,18 +1322,18 @@ def csv_dir() -> str:
         from_date = insert_timezone(layer_metadata.bounds.minKey.instant)
         to_date = insert_timezone(layer_metadata.bounds.maxKey.instant)
 
-        if polygons:
+        if projected_polygons:
             # TODO also add dumping results first to temp json file like with "mean"
             if func == 'histogram':
                 stats = self._compute_stats_geotrellis().compute_histograms_time_series_from_datacube(
-                    scala_data_cube, polygons, from_date.isoformat(), to_date.isoformat(), 0
+                    scala_data_cube, projected_polygons, from_date.isoformat(), to_date.isoformat(), 0
                 )
                 timeseries = self._as_python(stats)
             elif func == "mean":
                 with tempfile.NamedTemporaryFile(suffix=".json.tmp") as temp_file:
                     self._compute_stats_geotrellis().compute_average_timeseries_from_datacube(
                         scala_data_cube,
-                        polygons,
+                        projected_polygons,
                         from_date.isoformat(),
                         to_date.isoformat(),
                         0,
@@ -1347,7 +1347,7 @@ def csv_dir() -> str:
                     self._compute_stats_geotrellis().compute_generic_timeseries_from_datacube(
                         func,
                         wrapped,
-                        polygons,
+                        projected_polygons,
                         temp_output
                     )
                     return AggregatePolygonResultCSV(temp_output, regions=regions, metadata=self.metadata)
diff --git a/openeogeotrellis/layercatalog.py b/openeogeotrellis/layercatalog.py
index 7ced1e38c..89a59faa7 100644
--- a/openeogeotrellis/layercatalog.py
+++ b/openeogeotrellis/layercatalog.py
@@ -73,7 +73,8 @@ def create_datacube_parameters(self, load_params, env):
             getattr(datacubeParams, "layoutScheme_$eq")("FloatingLayoutScheme")
         return datacubeParams, single_level
 
-    @lru_cache(maxsize=20)
+    # FIXME: LoadParameters must be hashable but DriverVectorCube in aggregate_spatial_geometries isn't
+    # @lru_cache(maxsize=20)
     @TimingLogger(title="load_collection", logger=logger)
     def load_collection(self, collection_id: str, load_params: LoadParameters, env: EvalEnv) -> GeopysparkDataCube:
         logger.info("Creating layer for {c} with load params {p}".format(c=collection_id, p=load_params))
diff --git a/openeogeotrellis/utils.py b/openeogeotrellis/utils.py
index 063e1f994..46f6fdfd5 100644
--- a/openeogeotrellis/utils.py
+++ b/openeogeotrellis/utils.py
@@ -19,6 +19,7 @@
 from py4j.java_gateway import JavaGateway, JVMView
 from shapely.geometry import GeometryCollection, MultiPolygon, Polygon
 
+from openeo_driver.datacube import DriverVectorCube
 from openeo_driver.delayed_vector import DelayedVector
 from openeo_driver.util.logging import (get_logging_config, setup_logging, user_id_trim, BatchJobLoggingFilter,
                                         FlaskRequestCorrelationIdLogging, FlaskUserIdLogging, LOGGING_CONTEXT_BATCH_JOB)
@@ -193,6 +194,9 @@ def to_projected_polygons(jvm, *args):
         return jvm.org.openeo.geotrellis.ProjectedPolygons.fromVectorFile(str(args[0]))
     elif len(args) == 1 and isinstance(args[0], DelayedVector):
         return to_projected_polygons(jvm, args[0].path)
+    elif len(args) == 1 and isinstance(args[0], DriverVectorCube):
+        vc: DriverVectorCube = args[0]
+        return to_projected_polygons(jvm, GeometryCollection(list(vc.get_geometries())), str(vc.get_crs()))
     elif 1 <= len(args) <= 2 and isinstance(args[0], GeometryCollection):
         # Multiple polygons
         polygon_wkts = [str(x) for x in args[0].geoms]
diff --git a/tests/data/geometries/FeatureCollection.geojson b/tests/data/geometries/FeatureCollection.geojson
new file mode 100644
index 000000000..705107739
--- /dev/null
+++ b/tests/data/geometries/FeatureCollection.geojson
@@ -0,0 +1,70 @@
+{
+  "type": "FeatureCollection",
+  "properties": {},
+  "features": [
+    {
+      "type": "Feature",
+      "id": "apples",
+      "properties": {},
+      "geometry": {
+        "type": "Polygon",
+        "coordinates": [
+          [
+            [
+              0.17303466796874997,
+              0.15930155257140266
+            ],
+            [
+              0.65643310546875,
+              0.15930155257140266
+            ],
+            [
+              0.65643310546875,
+              0.4394488164139768
+            ],
+            [
+              0.17303466796874997,
+              0.4394488164139768
+            ],
+            [
+              0.17303466796874997,
+              0.15930155257140266
+            ]
+          ]
+        ]
+      }
+    },
+    {
+      "type": "Feature",
+      "id": "oranges",
+      "properties": {},
+      "geometry": {
+        "type": "Polygon",
+        "coordinates": [
+          [
+            [
+              1.5875244140625,
+              1.5790847267832018
+            ],
+            [
+              1.93084716796875,
+              1.5790847267832018
+            ],
+            [
+              1.93084716796875,
+              1.9332268264771233
+            ],
+            [
+              1.5875244140625,
+              1.9332268264771233
+            ],
+            [
+              1.5875244140625,
+              1.5790847267832018
+            ]
+          ]
+        ]
+      }
+    }
+  ]
+}
diff --git a/tests/test_api_result.py b/tests/test_api_result.py
index 98cd83b7d..918629908 100644
--- a/tests/test_api_result.py
+++ b/tests/test_api_result.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pytest
 import rasterio
+import xarray as xr
 from numpy.testing import assert_equal
 
 from openeo_driver.testing import TEST_USER
@@ -13,6 +14,7 @@
 from openeogeotrellis.testing import random_name
 from openeogeotrellis.utils import get_jvm, UtcNowClock
+from tests.data import get_test_data_file
 
 _log = logging.getLogger(__name__)
@@ -1324,3 +1326,61 @@ def test_apply_neighborhood_filter_spatial(api100, tmp_path):
     assert ds.width == 4
+
+
+def test_aggregate_spatial_netcdf_feature_names(api100, tmp_path):
+    response = api100.check_result({
+        'loadcollection1': {
+            'process_id': 'load_collection',
+            'arguments': {
+                "id": "TestCollection-LonLat4x4",
+                "temporal_extent": ["2021-01-01", "2021-02-20"],
+                "spatial_extent": {"west": 0.0, "south": 0.0, "east": 2.0, "north": 2.0},
+                "bands": ["Flat:1", "Month", "Day"]
+            }
+        },
+        'loaduploadedfiles1': {
+            'process_id': 'load_uploaded_files',
+            'arguments': {
+                'format': 'GeoJSON',
+                'paths': [str(get_test_data_file("geometries/FeatureCollection.geojson"))]
+            }
+        },
+        'aggregatespatial1': {
+            'process_id': 'aggregate_spatial',
+            'arguments': {
+                'data': {'from_node': 'loadcollection1'},
+                'geometries': {'from_node': 'loaduploadedfiles1'},
+                'reducer': {
+                    'process_graph': {
+                        'mean1': {
+                            'process_id': 'mean',
+                            'arguments': {
+                                'data': {'from_parameter': 'data'}
+                            },
+                            'result': True
+                        }
+                    }
+                }
+            }
+        },
+        "saveresult1": {
+            "process_id": "save_result",
+            "arguments": {
+                "data": {"from_node": "aggregatespatial1"},
+                "format": "netCDF"
+            },
+            "result": True
+        }
+    })
+
+    response.assert_status_code(200)
+
+    output_file = tmp_path / "test_aggregate_spatial_netcdf_feature_names.nc"
+
+    with open(output_file, mode="wb") as f:
+        f.write(response.data)
+
+    ds = xr.load_dataset(output_file)
+    assert ds["Flat:1"].sel(t='2021-02-05').values.tolist() == [1.0, 1.0]
+    assert ds["Month"].sel(t='2021-02-05').values.tolist() == [2.0, 2.0]
+    assert ds["Day"].sel(t='2021-02-05').values.tolist() == [5.0, 5.0]
+    assert ds.coords["feature_names"].values.tolist() == ["apples", "oranges"]
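
Reviewer note, not part of the patch: a minimal sketch of how the new DriverVectorCube branch of to_projected_polygons() gets exercised. Assumptions: DriverVectorCube wraps a geopandas GeoDataFrame and accepts one as its first constructor argument (only get_geometries() and get_crs() are confirmed by the diff above), and a live Py4J gateway is available via get_jvm().

    import geopandas as gpd
    from shapely.geometry import GeometryCollection

    from openeo_driver.datacube import DriverVectorCube
    from openeogeotrellis.utils import get_jvm, to_projected_polygons

    # Build a vector cube from the new test fixture (constructor signature is an
    # assumption; the patch itself only relies on the two accessors used below).
    gdf = gpd.read_file("tests/data/geometries/FeatureCollection.geojson")
    vector_cube = DriverVectorCube(gdf)

    # What the new elif branch does: unwrap the cube into plain shapely
    # geometries plus a CRS string, then recurse into the existing
    # GeometryCollection code path.
    geometries = GeometryCollection(list(vector_cube.get_geometries()))
    crs = str(vector_cube.get_crs())  # e.g. "EPSG:4326"

    # Equivalent single call, as now made by zonal_statistics (needs a JVM):
    projected_polygons = to_projected_polygons(get_jvm(), vector_cube)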