diff --git a/openeogeotrellis/layercatalog.py b/openeogeotrellis/layercatalog.py
index 89a59faa7..7ced1e38c 100644
--- a/openeogeotrellis/layercatalog.py
+++ b/openeogeotrellis/layercatalog.py
@@ -73,8 +73,7 @@ def create_datacube_parameters(self, load_params, env):
         getattr(datacubeParams, "layoutScheme_$eq")("FloatingLayoutScheme")
         return datacubeParams, single_level
 
-    # FIXME: LoadParameters must be hashable but DriverVectorCube in aggregate_spatial_geometries isn't
-    # @lru_cache(maxsize=20)
+    @lru_cache(maxsize=20)
     @TimingLogger(title="load_collection", logger=logger)
     def load_collection(self, collection_id: str, load_params: LoadParameters, env: EvalEnv) -> GeopysparkDataCube:
         logger.info("Creating layer for {c} with load params {p}".format(c=collection_id, p=load_params))
diff --git a/tests/data/geometries/FeatureCollection02.json b/tests/data/geometries/FeatureCollection02.json
new file mode 100644
index 000000000..1df4ed298
--- /dev/null
+++ b/tests/data/geometries/FeatureCollection02.json
@@ -0,0 +1,41 @@
+{
+  "type": "FeatureCollection",
+  "features": [
+    {
+      "type": "Feature",
+      "properties": {
+        "id": "first",
+        "pop": 1234
+      },
+      "geometry": {
+        "type": "Polygon",
+        "coordinates": [
+          [
+            [1, 1],
+            [3, 1],
+            [2, 3],
+            [1, 1]
+          ]
+        ]
+      }
+    },
+    {
+      "type": "Feature",
+      "properties": {
+        "id": "second",
+        "pop": 5678
+      },
+      "geometry": {
+        "type": "Polygon",
+        "coordinates": [
+          [
+            [4, 2],
+            [5, 4],
+            [3, 4],
+            [4, 2]
+          ]
+        ]
+      }
+    }
+  ]
+}
diff --git a/tests/test_api_result.py b/tests/test_api_result.py
index 918629908..6c76bde31 100644
--- a/tests/test_api_result.py
+++ b/tests/test_api_result.py
@@ -1,5 +1,6 @@
 import contextlib
 import logging
+import mock
 import textwrap
 from typing import List
 
@@ -1372,8 +1373,6 @@ def test_aggregate_spatial_netcdf_feature_names(api100, tmp_path):
         }
     })
 
-    response.assert_status_code(200)
-
     output_file = tmp_path / "test_aggregate_spatial_netcdf_feature_names.nc"
     with open(output_file, mode="wb") as f:
         f.write(response.data)
@@ -1384,3 +1383,82 @@ def test_aggregate_spatial_netcdf_feature_names(api100, tmp_path):
     assert ds["Month"].sel(t='2021-02-05').values.tolist() == [2.0, 2.0]
     assert ds["Day"].sel(t='2021-02-05').values.tolist() == [5.0, 5.0]
     assert ds.coords["feature_names"].values.tolist() == ["apples", "oranges"]
+
+
+def test_load_collection_is_cached(api100):
+    # unflattening this process graph will result in two calls to load_collection, unless it is cached
+
+    process_graph = {
+        'loadcollection1': {
+            'process_id': 'load_collection',
+            'arguments': {
+                "id": "TestCollection-LonLat4x4",
+                "temporal_extent": ["2021-01-01", "2021-02-20"],
+                "spatial_extent": {"west": 0.0, "south": 0.0, "east": 2.0, "north": 2.0},
+                "bands": ["Flat:1", "Month"]
+            }
+        },
+        'filterbands1': {
+            'process_id': 'filter_bands',
+            'arguments': {
+                'bands': ['Flat:1'],
+                'data': {'from_node': 'loadcollection1'}
+            }
+        },
+        'filterbands2': {
+            'process_id': 'filter_bands',
+            'arguments': {
+                'bands': ['Month'],
+                'data': {'from_node': 'loadcollection1'}
+            }
+        },
+        'mergecubes1': {
+            'process_id': 'merge_cubes',
+            'arguments': {
+                'cube1': {'from_node': 'filterbands1'},
+                'cube2': {'from_node': 'filterbands2'}
+            }
+        },
+        'loaduploadedfiles1': {
+            'process_id': 'load_uploaded_files',
+            'arguments': {
+                'format': 'GeoJSON',
+                'paths': [str(get_test_data_file("geometries/FeatureCollection.geojson"))]
+            }
+        },
+        'aggregatespatial1': {
+            'process_id': 'aggregate_spatial',
+            'arguments': {
+                'data': {'from_node': 'mergecubes1'},
+                'geometries': {'from_node': 'loaduploadedfiles1'},
+                'reducer': {
+                    'process_graph': {
+                        'mean1': {
+                            'process_id': 'mean',
+                            'arguments': {
+                                'data': {'from_parameter': 'data'}
+                            },
+                            'result': True}
+                    }
+                }
+            }, 'result': True
+        }
+    }
+
+    with mock.patch('openeogeotrellis.layercatalog.logger') as logger:
+        result = api100.check_result(process_graph).json
+
+    assert result == {
+        "2021-01-05T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
+        "2021-01-15T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
+        "2021-01-25T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
+        "2021-02-05T00:00:00Z": [[1.0, 2.0], [1.0, 2.0]],
+        "2021-02-15T00:00:00Z": [[1.0, 2.0], [1.0, 2.0]],
+    }
+
+    # TODO: is there an easier way to count the calls to lru_cache-decorated function load_collection?
+    creating_layer_calls = list(filter(lambda call: call.args[0].startswith("Creating layer for TestCollection-LonLat4x4"),
+                                       logger.info.call_args_list))
+
+    n_load_collection_calls = len(creating_layer_calls)
+    assert n_load_collection_calls == 1
diff --git a/tests/test_load_collection.py b/tests/test_load_collection.py
index da11b6499..99363d20f 100644
--- a/tests/test_load_collection.py
+++ b/tests/test_load_collection.py
@@ -3,13 +3,16 @@
 from mock import MagicMock, ANY
 
 from openeo_driver.backend import LoadParameters
+from openeo_driver.datacube import DriverVectorCube
 from openeo_driver.datastructs import SarBackscatterArgs
 from openeo_driver.errors import OpenEOApiException
 from openeo_driver.utils import EvalEnv
 from py4j.java_gateway import JavaGateway
+from tests.data import get_test_data_file
 
 from openeogeotrellis.geopysparkdatacube import GeopysparkDataCube
 from openeogeotrellis.layercatalog import get_layer_catalog
+import geopandas as gpd
 import geopyspark as gps
 
 from .test_api_result import CreoApiMocker, TerrascopeApiMocker
@@ -333,4 +336,30 @@ def test_load_disk_collection_batch(imagecollection_with_two_bands_and_three_dates,
     assert len(cube.metadata.spatial_dimensions) == 2
     assert len(cube.pyramid.levels)==1
 
-    print(cube.get_max_level().layer_metadata)
\ No newline at end of file
+    print(cube.get_max_level().layer_metadata)
+
+
+def test_driver_vector_cube_supports_load_collection_caching(jvm_mock):
+    catalog = get_layer_catalog()
+
+    def load_params1():
+        gdf = gpd.read_file(str(get_test_data_file("geometries/FeatureCollection.geojson")))
+        return LoadParameters(aggregate_spatial_geometries=DriverVectorCube(gdf))
+
+    def load_params2():
+        gdf = gpd.read_file(str(get_test_data_file("geometries/FeatureCollection02.json")))
+        return LoadParameters(aggregate_spatial_geometries=DriverVectorCube(gdf))
+
+    with mock.patch('openeogeotrellis.layercatalog.logger') as logger:
+        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))
+        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))
+        catalog.load_collection('SENTINEL1_GRD', load_params=load_params2(), env=EvalEnv({'pyramid_levels': 'highest'}))
+        catalog.load_collection('SENTINEL1_GRD', load_params=load_params2(), env=EvalEnv({'pyramid_levels': 'highest'}))
+        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))
+
+    # TODO: is there an easier way to count the calls to lru_cache-decorated function load_collection?
+    creating_layer_calls = list(filter(lambda call: call.args[0].startswith("Creating layer for SENTINEL1_GRD"),
+                                       logger.info.call_args_list))
+
+    n_load_collection_calls = len(creating_layer_calls)
+    assert n_load_collection_calls == 2
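
A note on the TODO that appears in both new tests ("is there an easier way to count the calls to the lru_cache-decorated load_collection?"): functools.lru_cache exposes hit/miss statistics on the wrapped function, so a check along the following lines could replace the log-message counting. This is only a sketch, assuming @lru_cache stays the outermost decorator on load_collection; `catalog`, `load_params1` and the EvalEnv value are taken from test_driver_vector_cube_supports_load_collection_caching above.

    env = EvalEnv({'pyramid_levels': 'highest'})
    info_before = catalog.load_collection.cache_info()  # functools.lru_cache hit/miss counters
    catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=env)
    catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=env)
    info_after = catalog.load_collection.cache_info()
    assert info_after.misses - info_before.misses == 1  # only the first call builds the layer
    assert info_after.hits - info_before.hits == 1      # the second call is answered from the cache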