Skip to content

Commit

Permalink
support DriverVectorCube in aggregate_spatial
Browse files Browse the repository at this point in the history
  • Loading branch information
bossie committed Sep 23, 2022
1 parent 5e10c6b commit 7d3c35a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 10 deletions.
18 changes: 9 additions & 9 deletions openeogeotrellis/geopysparkdatacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from openeo.udf import UdfData, run_udf_code
from openeo.udf.xarraydatacube import XarrayDataCube, XarrayIO
from openeo.util import dict_no_none, str_truncate
from openeo_driver.datacube import DriverDataCube
from openeo_driver.datacube import DriverDataCube, DriverVectorCube
from openeo_driver.datastructs import ResolutionMergeArgs
from openeo_driver.datastructs import SarBackscatterArgs
from openeo_driver.delayed_vector import DelayedVector
Expand Down Expand Up @@ -1246,14 +1246,14 @@ def raster_to_vector(self):
def get_max_level(self):
return self.pyramid.levels[self.pyramid.max_zoom]

def aggregate_spatial(self, geometries: Union[str, BaseGeometry], reducer,
def aggregate_spatial(self, geometries: Union[str, BaseGeometry, DriverVectorCube], reducer,
target_dimension: str = "result") -> Union[AggregatePolygonResult,
AggregateSpatialVectorCube]:

if isinstance(reducer, dict):
if len(reducer) == 1:
single_process = next(iter(reducer.values())).get('process_id')
return self.zonal_statistics(geometries,single_process)
return self.zonal_statistics(geometries, single_process)
else:
visitor = GeotrellisTileProcessGraphVisitor(_builder=self._get_jvm().org.openeo.geotrellis.aggregate_polygon.SparkAggregateScriptBuilder()).accept_process_graph(reducer)
return self.zonal_statistics(geometries, visitor.builder)
Expand All @@ -1263,7 +1263,7 @@ def aggregate_spatial(self, geometries: Union[str, BaseGeometry], reducer,
code="ReducerUnsupported", status_code=400
)

def zonal_statistics(self, regions: Union[str, BaseGeometry], func) -> Union[AggregatePolygonResult,
def zonal_statistics(self, regions: Union[str, BaseGeometry, DriverVectorCube], func) -> Union[AggregatePolygonResult,
AggregateSpatialVectorCube]:
# TODO: rename to aggregate_spatial?
# TODO eliminate code duplication
Expand All @@ -1287,7 +1287,7 @@ def csv_dir() -> str:
if isinstance(regions, (Polygon, MultiPolygon)):
regions = GeometryCollection([regions])

polygons = (None if isinstance(regions, Point) or
projected_polygons = (None if isinstance(regions, Point) or
(isinstance(regions, GeometryCollection) and
any(isinstance(geom, Point) for geom in regions.geoms))
else to_projected_polygons(self._get_jvm(), regions))
Expand Down Expand Up @@ -1322,18 +1322,18 @@ def csv_dir() -> str:
from_date = insert_timezone(layer_metadata.bounds.minKey.instant)
to_date = insert_timezone(layer_metadata.bounds.maxKey.instant)

if polygons:
if projected_polygons:
# TODO also add dumping results first to temp json file like with "mean"
if func == 'histogram':
stats = self._compute_stats_geotrellis().compute_histograms_time_series_from_datacube(
scala_data_cube, polygons, from_date.isoformat(), to_date.isoformat(), 0
scala_data_cube, projected_polygons, from_date.isoformat(), to_date.isoformat(), 0
)
timeseries = self._as_python(stats)
elif func == "mean":
with tempfile.NamedTemporaryFile(suffix=".json.tmp") as temp_file:
self._compute_stats_geotrellis().compute_average_timeseries_from_datacube(
scala_data_cube,
polygons,
projected_polygons,
from_date.isoformat(),
to_date.isoformat(),
0,
Expand All @@ -1347,7 +1347,7 @@ def csv_dir() -> str:
self._compute_stats_geotrellis().compute_generic_timeseries_from_datacube(
func,
wrapped,
polygons,
projected_polygons,
temp_output
)
return AggregatePolygonResultCSV(temp_output, regions=regions, metadata=self.metadata)
Expand Down
3 changes: 2 additions & 1 deletion openeogeotrellis/layercatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def create_datacube_parameters(self, load_params, env):
getattr(datacubeParams, "layoutScheme_$eq")("FloatingLayoutScheme")
return datacubeParams, single_level

@lru_cache(maxsize=20)
# FIXME: LoadParameters must be hashable but DriverVectorCube in aggregate_spatial_geometries isn't
# @lru_cache(maxsize=20)
@TimingLogger(title="load_collection", logger=logger)
def load_collection(self, collection_id: str, load_params: LoadParameters, env: EvalEnv) -> GeopysparkDataCube:
logger.info("Creating layer for {c} with load params {p}".format(c=collection_id, p=load_params))
Expand Down
4 changes: 4 additions & 0 deletions openeogeotrellis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from py4j.java_gateway import JavaGateway, JVMView
from shapely.geometry import GeometryCollection, MultiPolygon, Polygon

from openeo_driver.datacube import DriverVectorCube
from openeo_driver.delayed_vector import DelayedVector
from openeo_driver.util.logging import (get_logging_config, setup_logging, user_id_trim, BatchJobLoggingFilter,
FlaskRequestCorrelationIdLogging, FlaskUserIdLogging, LOGGING_CONTEXT_BATCH_JOB)
Expand Down Expand Up @@ -193,6 +194,9 @@ def to_projected_polygons(jvm, *args):
return jvm.org.openeo.geotrellis.ProjectedPolygons.fromVectorFile(str(args[0]))
elif len(args) == 1 and isinstance(args[0], DelayedVector):
return to_projected_polygons(jvm, args[0].path)
elif len(args) == 1 and isinstance(args[0], DriverVectorCube):
vc: DriverVectorCube = args[0]
return to_projected_polygons(jvm, GeometryCollection(list(vc.get_geometries())), str(vc.get_crs()))
elif 1 <= len(args) <= 2 and isinstance(args[0], GeometryCollection):
# Multiple polygons
polygon_wkts = [str(x) for x in args[0].geoms]
Expand Down
70 changes: 70 additions & 0 deletions tests/data/geometries/FeatureCollection.geojson
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{
"type": "FeatureCollection",
"properties": {},
"features": [
{
"type": "Feature",
"id": "apples",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[
0.17303466796874997,
0.15930155257140266
],
[
0.65643310546875,
0.15930155257140266
],
[
0.65643310546875,
0.4394488164139768
],
[
0.17303466796874997,
0.4394488164139768
],
[
0.17303466796874997,
0.15930155257140266
]
]
]
}
},
{
"type": "Feature",
"id": "oranges",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[
[
1.5875244140625,
1.5790847267832018
],
[
1.93084716796875,
1.5790847267832018
],
[
1.93084716796875,
1.9332268264771233
],
[
1.5875244140625,
1.9332268264771233
],
[
1.5875244140625,
1.5790847267832018
]
]
]
}
}
]
}
60 changes: 60 additions & 0 deletions tests/test_api_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import numpy as np
import pytest
import rasterio
import xarray as xr
from numpy.testing import assert_equal

from openeo_driver.testing import TEST_USER
from shapely.geometry import box, mapping

from openeogeotrellis.testing import random_name
from openeogeotrellis.utils import get_jvm, UtcNowClock
from tests.data import get_test_data_file

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -1324,3 +1326,61 @@ def test_apply_neighborhood_filter_spatial(api100, tmp_path):
assert ds.width == 4


def test_aggregate_spatial_netcdf_feature_names(api100, tmp_path):
    """aggregate_spatial over a GeoJSON FeatureCollection, saved as netCDF:
    the feature ids ("apples", "oranges") must survive as the "feature_names"
    coordinate, and each band must carry the expected per-feature means."""
    mean_reducer = {
        'process_graph': {
            'mean1': {
                'process_id': 'mean',
                'arguments': {
                    'data': {'from_parameter': 'data'}
                },
                'result': True
            }
        }
    }

    process_graph = {
        'loadcollection1': {
            'process_id': 'load_collection',
            'arguments': {
                "id": "TestCollection-LonLat4x4",
                "temporal_extent": ["2021-01-01", "2021-02-20"],
                "spatial_extent": {"west": 0.0, "south": 0.0, "east": 2.0, "north": 2.0},
                "bands": ["Flat:1", "Month", "Day"]
            }
        },
        # Geometries come in as a vector cube via load_uploaded_files.
        'loaduploadedfiles1': {
            'process_id': 'load_uploaded_files',
            'arguments': {
                'format': 'GeoJSON',
                'paths': [str(get_test_data_file("geometries/FeatureCollection.geojson"))]
            }
        },
        'aggregatespatial1': {
            'process_id': 'aggregate_spatial',
            'arguments': {
                'data': {'from_node': 'loadcollection1'},
                'geometries': {'from_node': 'loaduploadedfiles1'},
                'reducer': mean_reducer
            }
        },
        "saveresult1": {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "aggregatespatial1"},
                "format": "netCDF"
            },
            "result": True
        }
    }

    response = api100.check_result(process_graph)
    response.assert_status_code(200)

    # Round-trip the netCDF bytes through disk so xarray can open them.
    output_file = tmp_path / "test_aggregate_spatial_netcdf_feature_names.nc"
    with open(output_file, mode="wb") as f:
        f.write(response.data)

    ds = xr.load_dataset(output_file)

    # One mean per feature (two features), per band, at a fixed date.
    for band, expected in [("Flat:1", 1.0), ("Month", 2.0), ("Day", 5.0)]:
        assert ds[band].sel(t='2021-02-05').values.tolist() == [expected, expected]

    # Feature ids from the GeoJSON must be preserved as coordinate labels.
    assert ds.coords["feature_names"].values.tolist() == ["apples", "oranges"]

0 comments on commit 7d3c35a

Please sign in to comment.