diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py
index c724c6518..c64cffb7c 100644
--- a/openeo/rest/datacube.py
+++ b/openeo/rest/datacube.py
@@ -19,0 +20,2 @@
+import warnings
+from openeo.internal.warnings import UserDeprecationWarning
@@ -58,6 +58,9 @@ class DataCube(_ProcessGraphAbstraction):
     and this process graph can be "grown" to a desired workflow by calling the appropriate methods.
     """
 
+    # TODO: set this based on back-end or user preference?
+    _DEFAULT_RASTER_FORMAT = "GTiff"
+
     def __init__(self, graph: PGNode, connection: 'openeo.Connection', metadata: CollectionMetadata = None):
         super().__init__(pgnode=graph, connection=connection)
         self.metadata = CollectionMetadata.get_or_create(metadata)
@@ -1810,8 +1813,11 @@ def atmospheric_correction(
         })
 
     @openeo_process
-    def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube':
+    def save_result(
+        self, format: str = _DEFAULT_RASTER_FORMAT, options: Optional[dict] = None
+    ) -> "DataCube":
         formats = set(self._connection.list_output_formats().keys())
+        # TODO: map format to correct casing too?
         if format.lower() not in {f.lower() for f in formats}:
             raise ValueError("Invalid format {f!r}. Should be one of {s}".format(f=format, s=formats))
         return self.process(
@@ -1819,27 +1825,29 @@ def save_result(self, format: str = "GTiff", options: dict = None) -> 'DataCube'
             arguments={
                 "data": THIS,
                 "format": format,
+                # TODO: leave out options if unset?
                 "options": options or {}
             }
         )
 
-    def download(
-        self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
-        options: Optional[dict] = None
-    ):
+    def _ensure_save_result(
+        self, format: Optional[str] = None, options: Optional[dict] = None
+    ) -> "DataCube":
         """
-        Download image collection, e.g. as GeoTIFF.
-        If outputfile is provided, the result is stored on disk locally, otherwise, a bytes object is returned.
-        The bytes object can be passed on to a suitable decoder for decoding.
+        Make sure there is a (final) `save_result` node in the process graph.
+        If there is already one, check that it is consistent with the given format/options (if any);
+        otherwise, add a new `save_result` node.
 
-        :param outputfile: Optional, an output file if the result needs to be stored on disk.
-        :param format: Optional, an output format supported by the backend.
-        :param options: Optional, file format options
-        :return: None if the result is stored to disk, or a bytes object returned by the backend.
+        :param format: (optional) desired `save_result` file format
+        :param options: (optional) desired `save_result` file format parameters
+        :return: a :py:class:`DataCube` whose process graph ends in a `save_result` node
         """
-        if self.result_node().process_id == "save_result":
-            # There is already a `save_result` node: check if it is consistent with given format/options
-            args = self.result_node().arguments
+        # TODO: move to generic data cube parent class (not only for raster cubes, but also vector cubes)
+        result_node = self.result_node()
+        if result_node.process_id == "save_result":
+            # There is already a `save_result` node:
+            # check if it is consistent with given format/options (if any)
+            args = result_node.arguments
             if format is not None and format.lower() != args["format"].lower():
                 raise ValueError(
                     f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
@@ -1851,10 +1859,28 @@ def download(
             cube = self
         else:
             # No `save_result` node yet: automatically add it.
-            if not format:
-                format = guess_format(outputfile) if outputfile else "GTiff"
-            cube = self.save_result(format=format, options=options)
+            cube = self.save_result(
+                format=format or self._DEFAULT_RASTER_FORMAT, options=options
+            )
+        return cube
+
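For reference, the branching that `_ensure_save_result` implements can be condensed into a standalone sketch; the function name and the returned labels below are illustrative stand-ins, not part of the patch:

```python
from typing import Optional

def ensure_save_result_decision(
    final_process_id: str,
    existing_format: Optional[str],
    requested_format: Optional[str],
) -> str:
    """Summarize which action `_ensure_save_result` takes (sketch, not the real API)."""
    if final_process_id == "save_result":
        # Existing node: only a conflicting (case-insensitive) format is an error.
        if requested_format is not None and requested_format.lower() != existing_format.lower():
            return "raise ValueError: conflicting formats"
        return "reuse existing save_result node"
    # No save_result node yet: append one (GTiff default when no format is given).
    return "append new save_result node"

assert ensure_save_result_decision("load_collection", None, None) == "append new save_result node"
assert ensure_save_result_decision("save_result", "GTiff", "gtiff") == "reuse existing save_result node"
assert ensure_save_result_decision("save_result", "GTiff", "netCDF") == "raise ValueError: conflicting formats"
```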
+    def download(
+        self, outputfile: Union[str, pathlib.Path, None] = None, format: Optional[str] = None,
+        options: Optional[dict] = None
+    ):
+        """
+        Download image collection, e.g. as GeoTIFF.
+        If `outputfile` is provided, the result is stored on disk locally; otherwise, a bytes object is returned.
+        The bytes object can be passed on to a suitable decoder for decoding.
+
+        :param outputfile: Optional, an output file if the result needs to be stored on disk.
+        :param format: Optional, an output format supported by the backend.
+        :param options: Optional, file format options.
+        :return: None if the result is stored to disk, or a bytes object returned by the backend.
+        """
+        if format is None and outputfile is not None:
+            format = guess_format(outputfile)
+        cube = self._ensure_save_result(format=format, options=options)
         return self._connection.download(cube.flat_graph(), outputfile)
 
     def validate(self) -> List[dict]:
@@ -1869,27 +1895,36 @@ def tiled_viewing_service(self, type: str, **kwargs) -> Service:
         return self._connection.create_service(self.flat_graph(), type=type, **kwargs)
 
     def execute_batch(
-            self,
-            outputfile: Union[str, pathlib.Path] = None, out_format: str = None,
-            print=print, max_poll_interval=60, connection_retry_interval=30,
-            job_options=None, **format_options) -> BatchJob:
+        self,
+        outputfile: Optional[Union[str, pathlib.Path]] = None,
+        out_format: Optional[str] = None,
+        *,
+        print: typing.Callable[[str], None] = print,
+        max_poll_interval: float = 60,
+        connection_retry_interval: float = 30,
+        job_options: Optional[dict] = None,
+        # TODO: avoid `format_options` as keyword arguments
+        **format_options,
+    ) -> BatchJob:
         """
         Evaluate the process graph by creating a batch job, and retrieving the results when it is finished.
         This method is mostly recommended if the batch job is expected to run in a reasonable amount of time.
 
         For very long-running jobs, you probably do not want to keep the client running.
 
-        :param job_options:
         :param outputfile: The path of a file to which a result can be written
-        :param out_format: (optional) Format of the job result.
-        :param format_options: String Parameters for the job result format
-
+        :param out_format: (optional) output file format; deprecated, use `format` instead.
+        :param job_options: custom job options.
+        :param format_options: output file format parameters.
         """
+        if out_format is not None:
+            warnings.warn(
+                "In `execute_batch`, argument `out_format` is deprecated, use `format` instead.",
+                category=UserDeprecationWarning,
+                stacklevel=2,
+            )
         if "format" in format_options and not out_format:
-            out_format = format_options["format"]  # align with 'download' call arg name
-        if not out_format:
-            out_format = guess_format(outputfile) if outputfile else "GTiff"
-        job = self.create_job(out_format, job_options=job_options, **format_options)
+            out_format = format_options.pop("format")  # align with 'download' call arg name
+        if not out_format and outputfile:
+            out_format = guess_format(outputfile)
+
+        job = self.create_job(
+            out_format=out_format, job_options=job_options, **format_options
+        )
         return job.run_synchronous(
             outputfile=outputfile,
             print=print, max_poll_interval=max_poll_interval, connection_retry_interval=connection_retry_interval
         )
@@ -1904,6 +1939,7 @@ def create_job(
         plan: Optional[str] = None,
         budget: Optional[float] = None,
         job_options: Optional[dict] = None,
+        # TODO: avoid `format_options` as keyword arguments
         **format_options,
     ) -> BatchJob:
         """
@@ -1914,22 +1950,19 @@ def create_job(
         it still needs to be started and tracked explicitly.
         Use :py:meth:`execute_batch` instead to have the openEO Python client take care of that job management.
 
-        :param out_format: String Format of the job result.
+        :param out_format: output file format.
         :param title: job title
         :param description: job description
         :param plan: billing plan
         :param budget: maximum cost the request is allowed to produce
-        :param job_options: A dictionary containing (custom) job options
-        :param format_options: String Parameters for the job result format
+        :param job_options: custom job options.
+        :param format_options: output file format parameters.
         :return: Created job.
         """
         # TODO: add option to also automatically start the job?
         # TODO: avoid using all kwargs as format_options
         # TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
-        cube = self
-        if out_format:
-            # add `save_result` node
-            cube = cube.save_result(format=out_format, options=format_options)
+        cube = self._ensure_save_result(format=out_format, options=format_options or None)
         return self._connection.create_job(
             process_graph=cube.flat_graph(),
             title=title,
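Both `download` and `execute_batch` now delegate extension-based format guessing to `guess_format` (typed in the `openeo/util.py` hunk below). A rough self-contained sketch of that behaviour; the extension table here is an assumption, only the "unknown names pass through" rule is documented:

```python
from pathlib import Path
from typing import Union

# Assumed extension table; the real table lives in openeo.util.guess_format.
_FORMAT_GUESSES = {"tif": "GTiff", "tiff": "GTiff", "gtiff": "GTiff", "nc": "NetCDF", "netcdf": "NetCDF"}

def guess_format_sketch(filename: Union[str, Path]) -> str:
    suffix = Path(filename).suffix.lstrip(".").lower()
    # Known extensions are corrected to canonical casing; unknown names pass through.
    return _FORMAT_GUESSES.get(suffix, suffix)

assert guess_format_sketch("result.TIF") == "GTiff"
assert guess_format_sketch(Path("cube.nc")) == "NetCDF"
assert guess_format_sketch("dump.foo") == "foo"  # passed through unchanged
```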
diff --git a/openeo/util.py b/openeo/util.py
index b92fdb7d5..031944915 100644
--- a/openeo/util.py
+++ b/openeo/util.py
@@ -437,7 +437,7 @@ def deep_set(data: dict, *keys, value):
         raise ValueError("No keys given")
 
 
-def guess_format(filename: Union[str, Path]):
+def guess_format(filename: Union[str, Path]) -> str:
     """
     Guess the output format from a given filename and return the corrected format.
     Any names not in the dict get passed through.
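The tests that follow avoid real job creation by patching `Connection.create_job` and inspecting the recorded call. The same interception pattern in isolation, using an assumed `FakeConnection` stand-in so the sketch is self-contained:

```python
from unittest import mock

class FakeConnection:
    """Assumed stand-in for openeo.Connection, just for a self-contained demo."""
    def create_job(self, process_graph: dict, **kwargs):
        raise NotImplementedError("never runs: patched out below")

connection = FakeConnection()
with mock.patch.object(connection, "create_job") as create_job:
    # Code under test would call this internally; here we call it directly:
    connection.create_job(process_graph={"save1": {"process_id": "save_result", "result": True}})
    # The MagicMock records the call, so the test can pull out the submitted graph:
    assert create_job.call_count == 1
    pg = create_job.call_args.kwargs["process_graph"]
    assert pg["save1"]["process_id"] == "save_result"
```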
"saveresult1"} + assert pg["saveresult1"] == { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": expected, + "options": {}, + }, + "result": True, + } + assert recwarn.list == [] + assert caplog.records == [] + + @pytest.mark.parametrize( + ["out_format", "expected"], + [("GTiff", "GTiff"), ("NetCDF", "NetCDF")], + ) + def test_out_format( + self, connection, s2cube, get_create_job_pg, out_format, expected + ): + with pytest.warns( + UserDeprecationWarning, + match="`out_format`.*is deprecated.*use `format` instead", + ): + s2cube.execute_batch(out_format=out_format) + pg = get_create_job_pg() + assert set(pg.keys()) == {"loadcollection1", "saveresult1"} + assert pg["saveresult1"] == { + "process_id": "save_result", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "format": expected, + "options": {}, + }, + "result": True, + }