assert load_collection is cached
bossie committed Sep 23, 2022
1 parent 7d3c35a commit cd9747a
Showing 4 changed files with 152 additions and 5 deletions.
3 changes: 1 addition & 2 deletions openeogeotrellis/layercatalog.py
@@ -73,8 +73,7 @@ def create_datacube_parameters(self, load_params, env):
         getattr(datacubeParams, "layoutScheme_$eq")("FloatingLayoutScheme")
         return datacubeParams, single_level

-    # FIXME: LoadParameters must be hashable but DriverVectorCube in aggregate_spatial_geometries isn't
-    # @lru_cache(maxsize=20)
+    @lru_cache(maxsize=20)
     @TimingLogger(title="load_collection", logger=logger)
     def load_collection(self, collection_id: str, load_params: LoadParameters, env: EvalEnv) -> GeopysparkDataCube:
         logger.info("Creating layer for {c} with load params {p}".format(c=collection_id, p=load_params))
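The removed FIXME captures the crux of this commit: functools.lru_cache derives its cache key from the call arguments, so turning the decorator on requires LoadParameters, including any DriverVectorCube under aggregate_spatial_geometries, to be hashable and to compare by value (the new test_driver_vector_cube_supports_load_collection_caching below exercises exactly that). A minimal sketch of the requirement, with illustrative stand-in classes rather than the openeo_driver implementations:

```python
from functools import lru_cache


class FrozenParams:
    """Illustrative stand-in for LoadParameters: cacheable because it is
    immutable and implements value-based __eq__/__hash__."""

    def __init__(self, extent: tuple):
        self._key = ("extent", extent)

    def __eq__(self, other):
        return isinstance(other, FrozenParams) and self._key == other._key

    def __hash__(self):
        return hash(self._key)


@lru_cache(maxsize=20)
def load(params: FrozenParams):
    print("cache miss")  # only printed when lru_cache can't find the key
    return object()  # stand-in for an expensive-to-build layer


a = load(FrozenParams(extent=(0.0, 0.0, 2.0, 2.0)))
b = load(FrozenParams(extent=(0.0, 0.0, 2.0, 2.0)))  # equal key: served from cache
assert a is b
```

An unhashable argument (for instance one wrapping a plain GeoDataFrame, which defines no __hash__) would instead make every decorated call raise TypeError.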
41 changes: 41 additions & 0 deletions tests/data/geometries/FeatureCollection02.json
@@ -0,0 +1,41 @@
{
  "type": "FeatureCollection",
  "features": [
    {
      "type": "Feature",
      "properties": {
        "id": "first",
        "pop": 1234
      },
      "geometry": {
        "type": "Polygon",
        "coordinates": [
          [
            [1, 1],
            [3, 1],
            [2, 3],
            [1, 1]
          ]
        ]
      }
    },
    {
      "type": "Feature",
      "properties": {
        "id": "second",
        "pop": 5678
      },
      "geometry": {
        "type": "Polygon",
        "coordinates": [
          [
            [4, 2],
            [5, 4],
            [3, 4],
            [4, 2]
          ]
        ]
      }
    }
  ]
}
82 changes: 80 additions & 2 deletions tests/test_api_result.py
@@ -1,5 +1,6 @@
 import contextlib
 import logging
+import mock
 import textwrap
 from typing import List

@@ -1372,8 +1373,6 @@ def test_aggregate_spatial_netcdf_feature_names(api100, tmp_path):
         }
     })

-    response.assert_status_code(200)
-
     output_file = tmp_path / "test_aggregate_spatial_netcdf_feature_names.nc"

     with open(output_file, mode="wb") as f:
@@ -1384,3 +1383,82 @@
     assert ds["Month"].sel(t='2021-02-05').values.tolist() == [2.0, 2.0]
     assert ds["Day"].sel(t='2021-02-05').values.tolist() == [5.0, 5.0]
     assert ds.coords["feature_names"].values.tolist() == ["apples", "oranges"]


def test_load_collection_is_cached(api100):
    # Unflattening this process graph results in two load_collection calls, unless the result is cached.
    process_graph = {
        'loadcollection1': {
            'process_id': 'load_collection',
            'arguments': {
                "id": "TestCollection-LonLat4x4",
                "temporal_extent": ["2021-01-01", "2021-02-20"],
                "spatial_extent": {"west": 0.0, "south": 0.0, "east": 2.0, "north": 2.0},
                "bands": ["Flat:1", "Month"]
            }
        },
        'filterbands1': {
            'process_id': 'filter_bands',
            'arguments': {
                'bands': ['Flat:1'],
                'data': {'from_node': 'loadcollection1'}
            }
        },
        'filterbands2': {
            'process_id': 'filter_bands',
            'arguments': {
                'bands': ['Month'],
                'data': {'from_node': 'loadcollection1'}
            }
        },
        'mergecubes1': {
            'process_id': 'merge_cubes',
            'arguments': {
                'cube1': {'from_node': 'filterbands1'},
                'cube2': {'from_node': 'filterbands2'}
            }
        },
        'loaduploadedfiles1': {
            'process_id': 'load_uploaded_files',
            'arguments': {
                'format': 'GeoJSON',
                'paths': [str(get_test_data_file("geometries/FeatureCollection.geojson"))]
            }
        },
        'aggregatespatial1': {
            'process_id': 'aggregate_spatial',
            'arguments': {
                'data': {'from_node': 'mergecubes1'},
                'geometries': {'from_node': 'loaduploadedfiles1'},
                'reducer': {
                    'process_graph': {
                        'mean1': {
                            'process_id': 'mean',
                            'arguments': {
                                'data': {'from_parameter': 'data'}
                            },
                            'result': True
                        }
                    }
                }
            },
            'result': True
        }
    }

    with mock.patch('openeogeotrellis.layercatalog.logger') as logger:
        result = api100.check_result(process_graph).json

    assert result == {
        "2021-01-05T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
        "2021-01-15T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
        "2021-01-25T00:00:00Z": [[1.0, 1.0], [1.0, 1.0]],
        "2021-02-05T00:00:00Z": [[1.0, 2.0], [1.0, 2.0]],
        "2021-02-15T00:00:00Z": [[1.0, 2.0], [1.0, 2.0]],
    }

    # TODO: is there an easier way to count the calls to the lru_cache-decorated load_collection?
    creating_layer_calls = list(filter(lambda call: call.args[0].startswith("Creating layer for TestCollection-LonLat4x4"),
                                       logger.info.call_args_list))

    n_load_collection_calls = len(creating_layer_calls)
    assert n_load_collection_calls == 1
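The TODO above has a lighter-weight alternative: a functools.lru_cache wrapper exposes cache_info() and cache_clear(). A sketch, assuming the decorated method lives on the class returned by get_layer_catalog (GeoPySparkLayerCatalog is an assumed name here) and that the cache is cleared first, since lru_cache state is process-wide and shared across tests:

```python
from openeogeotrellis.layercatalog import GeoPySparkLayerCatalog  # assumed class name

GeoPySparkLayerCatalog.load_collection.cache_clear()  # start from an empty cache
api100.check_result(process_graph)
info = GeoPySparkLayerCatalog.load_collection.cache_info()
assert info.misses == 1  # the underlying load_collection ran exactly once
assert info.hits == 1    # the second call produced by unflattening hit the cache
```

The log-based count used in the test avoids that shared-cache bookkeeping, at the cost of coupling the assertion to a log message.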
31 changes: 30 additions & 1 deletion tests/test_load_collection.py
@@ -3,13 +3,16 @@
 from mock import MagicMock, ANY

 from openeo_driver.backend import LoadParameters
+from openeo_driver.datacube import DriverVectorCube
 from openeo_driver.datastructs import SarBackscatterArgs
 from openeo_driver.errors import OpenEOApiException
 from openeo_driver.utils import EvalEnv
 from py4j.java_gateway import JavaGateway
 from tests.data import get_test_data_file

 from openeogeotrellis.geopysparkdatacube import GeopysparkDataCube
 from openeogeotrellis.layercatalog import get_layer_catalog
+import geopandas as gpd
+import geopyspark as gps

 from .test_api_result import CreoApiMocker, TerrascopeApiMocker
@@ -333,4 +336,30 @@ def test_load_disk_collection_batch(imagecollection_with_two_bands_and_three_dat

     assert len(cube.metadata.spatial_dimensions) == 2
     assert len(cube.pyramid.levels)==1
-    print(cube.get_max_level().layer_metadata)
\ No newline at end of file
+    print(cube.get_max_level().layer_metadata)


def test_driver_vector_cube_supports_load_collection_caching(jvm_mock):
    catalog = get_layer_catalog()

    def load_params1():
        gdf = gpd.read_file(str(get_test_data_file("geometries/FeatureCollection.geojson")))
        return LoadParameters(aggregate_spatial_geometries=DriverVectorCube(gdf))

    def load_params2():
        gdf = gpd.read_file(str(get_test_data_file("geometries/FeatureCollection02.json")))
        return LoadParameters(aggregate_spatial_geometries=DriverVectorCube(gdf))

    with mock.patch('openeogeotrellis.layercatalog.logger') as logger:
        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))
        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))
        catalog.load_collection('SENTINEL1_GRD', load_params=load_params2(), env=EvalEnv({'pyramid_levels': 'highest'}))
        catalog.load_collection('SENTINEL1_GRD', load_params=load_params2(), env=EvalEnv({'pyramid_levels': 'highest'}))
        catalog.load_collection('SENTINEL1_GRD', load_params=load_params1(), env=EvalEnv({'pyramid_levels': 'highest'}))

    # TODO: is there an easier way to count the calls to the lru_cache-decorated load_collection?
    creating_layer_calls = list(filter(lambda call: call.args[0].startswith("Creating layer for SENTINEL1_GRD"),
                                       logger.info.call_args_list))

    n_load_collection_calls = len(creating_layer_calls)
    assert n_load_collection_calls == 2
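The expected count of 2 (from 5 calls) hinges on value-based equality: LoadParameters built from the same GeoJSON file must hash and compare equal so that lru_cache returns the cached cube, while the FeatureCollection02 variant must produce a different key. A sketch of that property, assuming DriverVectorCube and LoadParameters implement __eq__/__hash__ over their contents (which the caching asserted here relies on):

```python
import geopandas as gpd
from openeo_driver.backend import LoadParameters
from openeo_driver.datacube import DriverVectorCube
from tests.data import get_test_data_file


def params(filename: str) -> LoadParameters:
    # Same construction as load_params1/load_params2 in the test above.
    gdf = gpd.read_file(str(get_test_data_file(filename)))
    return LoadParameters(aggregate_spatial_geometries=DriverVectorCube(gdf))


p1a = params("geometries/FeatureCollection.geojson")
p1b = params("geometries/FeatureCollection.geojson")
p2 = params("geometries/FeatureCollection02.json")

assert p1a == p1b and hash(p1a) == hash(p1b)  # same contents: one cache entry
assert p1a != p2                              # different geometries: a second entry
```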
