Skip to content

Commit

Permalink
Support save_result with connectionless data cubes
Browse files Browse the repository at this point in the history
Follow-up of #638 and #639.
  • Loading branch information
soxofaan committed Oct 4, 2024
1 parent 744fbd6 commit 08fed55
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 19 deletions.
4 changes: 3 additions & 1 deletion openeo/rest/_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class _ProcessGraphAbstraction(_FromNodeMixin, FlatGraphableMixin):
raster data cubes, vector cubes, ML models, ...
"""

def __init__(self, pgnode: PGNode, connection: Connection):
def __init__(self, pgnode: PGNode, connection: Union[Connection, None]):
self._pg = pgnode
# TODO: now that connection can officially be None:
# improve exceptions in cases where is it still assumed to be a real connection (download, create_job, ...)
self._connection = connection

def __str__(self):
Expand Down
9 changes: 5 additions & 4 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2211,10 +2211,11 @@ def save_result(
format: str = _DEFAULT_RASTER_FORMAT,
options: Optional[dict] = None,
) -> DataCube:
formats = set(self._connection.list_output_formats().keys())
# TODO: map format to correct casing too?
if format.lower() not in {f.lower() for f in formats}:
raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
if self._connection:
formats = set(self._connection.list_output_formats().keys())
# TODO: map format to correct casing too?
if format.lower() not in {f.lower() for f in formats}:
raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
return self.process(
process_id="save_result",
arguments={
Expand Down
2 changes: 1 addition & 1 deletion openeo/rest/mlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MlModel(_ProcessGraphAbstraction):
.. versionadded:: 0.10.0
"""

def __init__(self, graph: PGNode, connection: Connection):
def __init__(self, graph: PGNode, connection: Union[Connection, None]):
super().__init__(pgnode=graph, connection=connection)

def save_ml_model(self, options: Optional[dict] = None):
Expand Down
2 changes: 1 addition & 1 deletion openeo/rest/vectorcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class VectorCube(_ProcessGraphAbstraction):

_DEFAULT_VECTOR_FORMAT = "GeoJSON"

def __init__(self, graph: PGNode, connection: Connection, metadata: Optional[CubeMetadata] = None):
def __init__(self, graph: PGNode, connection: Union[Connection, None], metadata: Optional[CubeMetadata] = None):
super().__init__(pgnode=graph, connection=connection)
self.metadata = metadata

Expand Down
54 changes: 42 additions & 12 deletions tests/rest/datacube/test_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,6 @@ def _get_leaf_node(cube, force_flat=True) -> dict:


class TestDataCube:
def test_load_stac_connectionless(self, connection):
expected_graph = {
"loadstac1": {
"process_id": "load_stac",
"arguments": {"url": "https://provider.test/dataset"},
"result": True,
}
}
cube = DataCube.load_stac("https://provider.test/dataset")
assert cube.flat_graph() == expected_graph
cube2 = connection.load_stac("https://provider.test/dataset")
assert cube2.flat_graph() == expected_graph

def test_load_collection_connectionless_basic(self):
cube = DataCube.load_collection("T3")
Expand Down Expand Up @@ -148,6 +136,48 @@ def test_load_collection_connectionless_temporal_extent_shortcut(self):
}
}

def test_load_collection_connectionless_save_result(self):
    """A connection-less `load_collection` cube can be capped with `save_result` and still flattens correctly."""
    expected = {
        "loadcollection1": {
            "process_id": "load_collection",
            "arguments": {"id": "T3", "spatial_extent": None, "temporal_extent": None},
        },
        "saveresult1": {
            "process_id": "save_result",
            "arguments": {
                "data": {"from_node": "loadcollection1"},
                "format": "GTiff",
                "options": {},
            },
            "result": True,
        },
    }
    cube = DataCube.load_collection("T3").save_result(format="GTiff")
    assert cube.flat_graph() == expected

def test_load_stac_connectionless_basic(self):
    """`DataCube.load_stac` without a connection produces a single `load_stac` result node."""
    expected = {
        "loadstac1": {
            "process_id": "load_stac",
            "arguments": {"url": "https://provider.test/dataset"},
            "result": True,
        }
    }
    cube = DataCube.load_stac("https://provider.test/dataset")
    assert cube.flat_graph() == expected

def test_load_stac_connectionless_save_result(self):
    """Chaining `save_result` onto a connection-less `load_stac` cube builds the two-node graph."""
    expected = {
        "loadstac1": {
            "process_id": "load_stac",
            "arguments": {"url": "https://provider.test/dataset"},
        },
        "saveresult1": {
            "process_id": "save_result",
            "arguments": {"data": {"from_node": "loadstac1"}, "format": "GTiff", "options": {}},
            "result": True,
        },
    }
    cube = DataCube.load_stac("https://provider.test/dataset").save_result(format="GTiff")
    assert cube.flat_graph() == expected


def test_filter_temporal_basic_positional_args(s2cube):
im = s2cube.filter_temporal("2016-01-01", "2016-03-10")
Expand Down

0 comments on commit 08fed55

Please sign in to comment.