diff --git a/openeo/rest/_datacube.py b/openeo/rest/_datacube.py index 6a8ffe1da..79fe5d5ea 100644 --- a/openeo/rest/_datacube.py +++ b/openeo/rest/_datacube.py @@ -37,8 +37,10 @@ class _ProcessGraphAbstraction(_FromNodeMixin, FlatGraphableMixin): raster data cubes, vector cubes, ML models, ... """ - def __init__(self, pgnode: PGNode, connection: Connection): + def __init__(self, pgnode: PGNode, connection: Union[Connection, None]): self._pg = pgnode + # TODO: now that connection can officially be None: + # improve exceptions in cases where is it still assumed to be a real connection (download, create_job, ...) self._connection = connection def __str__(self): diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index fe80c79c0..246e5161a 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -2211,10 +2211,11 @@ def save_result( format: str = _DEFAULT_RASTER_FORMAT, options: Optional[dict] = None, ) -> DataCube: - formats = set(self._connection.list_output_formats().keys()) - # TODO: map format to correct casing too? - if format.lower() not in {f.lower() for f in formats}: - raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats)) + if self._connection: + formats = set(self._connection.list_output_formats().keys()) + # TODO: map format to correct casing too? + if format.lower() not in {f.lower() for f in formats}: + raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats)) return self.process( process_id="save_result", arguments={ diff --git a/openeo/rest/mlmodel.py b/openeo/rest/mlmodel.py index 0ddb92598..07162e7f5 100644 --- a/openeo/rest/mlmodel.py +++ b/openeo/rest/mlmodel.py @@ -27,7 +27,7 @@ class MlModel(_ProcessGraphAbstraction): .. versionadded:: 0.10.0 """ - def __init__(self, graph: PGNode, connection: Connection): + def __init__(self, graph: PGNode, connection: Union[Connection, None]): super().__init__(pgnode=graph, connection=connection) def save_ml_model(self, options: Optional[dict] = None): diff --git a/openeo/rest/vectorcube.py b/openeo/rest/vectorcube.py index 0e941b790..09b35f77e 100644 --- a/openeo/rest/vectorcube.py +++ b/openeo/rest/vectorcube.py @@ -40,7 +40,7 @@ class VectorCube(_ProcessGraphAbstraction): _DEFAULT_VECTOR_FORMAT = "GeoJSON" - def __init__(self, graph: PGNode, connection: Connection, metadata: Optional[CubeMetadata] = None): + def __init__(self, graph: PGNode, connection: Union[Connection, None], metadata: Optional[CubeMetadata] = None): super().__init__(pgnode=graph, connection=connection) self.metadata = metadata diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py index 6b0da6e8b..83f82e488 100644 --- a/tests/rest/datacube/test_datacube.py +++ b/tests/rest/datacube/test_datacube.py @@ -83,18 +83,6 @@ def _get_leaf_node(cube, force_flat=True) -> dict: class TestDataCube: - def test_load_stac_connectionless(self, connection): - expected_graph = { - "loadstac1": { - "process_id": "load_stac", - "arguments": {"url": "https://provider.test/dataset"}, - "result": True, - } - } - cube = DataCube.load_stac("https://provider.test/dataset") - assert cube.flat_graph() == expected_graph - cube2 = connection.load_stac("https://provider.test/dataset") - assert cube2.flat_graph() == expected_graph def test_load_collection_connectionless_basic(self): cube = DataCube.load_collection("T3") @@ -148,6 +136,48 @@ def test_load_collection_connectionless_temporal_extent_shortcut(self): } } + def test_load_collection_connectionless_save_result(self): + cube = DataCube.load_collection("T3").save_result(format="GTiff") + assert cube.flat_graph() == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "T3", "spatial_extent": None, "temporal_extent": None}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": "GTiff", + "options": {}, + }, + "result": True, + }, + } + + def test_load_stac_connectionless_basic(self): + cube = DataCube.load_stac("https://provider.test/dataset") + assert cube.flat_graph() == { + "loadstac1": { + "process_id": "load_stac", + "arguments": {"url": "https://provider.test/dataset"}, + "result": True, + } + } + + def test_load_stac_connectionless_save_result(self): + cube = DataCube.load_stac("https://provider.test/dataset").save_result(format="GTiff") + assert cube.flat_graph() == { + "loadstac1": { + "process_id": "load_stac", + "arguments": {"url": "https://provider.test/dataset"}, + }, + "saveresult1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "loadstac1"}, "format": "GTiff", "options": {}}, + "result": True, + }, + } + def test_filter_temporal_basic_positional_args(s2cube): im = s2cube.filter_temporal("2016-01-01", "2016-03-10")