From 36489cb01be0c797a98a3579d53538159c795ae9 Mon Sep 17 00:00:00 2001 From: snowman2 Date: Mon, 1 Apr 2024 14:47:18 -0500 Subject: [PATCH] TYPE: Standardize resampling type --- datacube/api/core.py | 9 +++++---- datacube/storage/_load.py | 3 ++- datacube/utils/cog.py | 10 +++++++--- docs/about/whats_new.rst | 3 +++ tests/storage/test_storage_read.py | 28 ++++++++++++++++++++-------- 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/datacube/api/core.py b/datacube/api/core.py index 322868d25..eb3201d3c 100644 --- a/datacube/api/core.py +++ b/datacube/api/core.py @@ -18,6 +18,7 @@ from datacube.storage import reproject_and_fuse, BandInfo from datacube.utils import ignore_exceptions_if from odc.geo import CRS, yx_, res_, resyx_, Resolution, XY +from odc.geo.warp import Resampling from odc.geo.xr import xr_coords from datacube.utils.dates import normalise_dt from odc.geo.geom import intersects, box, bbox_union, Geometry @@ -244,7 +245,7 @@ def load(self, measurements: str | list[str] | None = None, output_crs: Any = None, resolution: int | float | tuple[int | float, int | float] | Resolution | None = None, - resampling: str | dict[str, str] | None = None, + resampling: Resampling | dict[str, Resampling] | None = None, align: XY[float] | Iterable[float] | None = None, skip_broken_datasets: bool = False, dask_chunks: dict[str, str | int] | None = None, @@ -878,7 +879,7 @@ def _cbk(*ignored): @staticmethod def load_data(sources: xarray.DataArray, geobox: GeoBox, measurements: Mapping[str, Measurement] | list[Measurement], - resampling: str | dict[str, str] | None = None, + resampling: Resampling | dict[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None, dask_chunks: dict[str, str | int] | None = None, skip_broken_datasets: bool = False, @@ -969,7 +970,7 @@ def __exit__(self, type_, value, traceback): def per_band_load_data_settings(measurements: list[Measurement] | Mapping[str, Measurement], - resampling: str | Mapping[str, str] | None = None, + resampling: Resampling | Mapping[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None ) -> list[Measurement]: def with_resampling(m, resampling, default=None): @@ -982,7 +983,7 @@ def with_fuser(m, fuser, default=None): m['fuser'] = fuser.get(m.name, default) return m - if isinstance(resampling, str): + if resampling is not None and not isinstance(resampling, dict): resampling = {'*': resampling} if fuse_func is None or callable(fuse_func): diff --git a/datacube/storage/_load.py b/datacube/storage/_load.py index ed0b8f5f0..6c25666dc 100644 --- a/datacube/storage/_load.py +++ b/datacube/storage/_load.py @@ -23,6 +23,7 @@ from odc.geo.geobox import GeoBox from odc.geo.roi import roi_is_empty from odc.geo.xr import xr_coords +from odc.geo.warp import Resampling from datacube.model import Measurement from datacube.drivers._types import ReaderDriver from ..drivers.datasource import DataSource @@ -47,7 +48,7 @@ def reproject_and_fuse(datasources: List[DataSource], destination: np.ndarray, dst_geobox: GeoBox, dst_nodata: Optional[Union[int, float]], - resampling: str = 'nearest', + resampling: Resampling = 'nearest', fuse_func: Optional[FuserFunction] = None, skip_broken_datasets: bool = False, progress_cbk: Optional[ProgressFunction] = None, diff --git a/datacube/utils/cog.py b/datacube/utils/cog.py index 460ebc40c..9b4654a5f 100644 --- a/datacube/utils/cog.py +++ b/datacube/utils/cog.py @@ -18,6 +18,7 @@ from .io import check_write_path from odc.geo.geobox import GeoBox from odc.geo.math import align_up +from odc.geo.warp import Resampling, resampling_s2rio from deprecat import deprecat @@ -38,7 +39,7 @@ def _write_cog( nodata: Optional[float] = None, overwrite: bool = False, blocksize: Optional[int] = None, - overview_resampling: Optional[str] = None, + overview_resampling: Optional[Resampling] = None, overview_levels: Optional[List[int]] = None, ovr_blocksize: Optional[int] = None, use_windowed_writes: bool = False, @@ -118,7 +119,10 @@ def _write_cog( fname, overwrite ) # aborts if overwrite=False and file exists already - resampling = rasterio.enums.Resampling[overview_resampling] + if isinstance(overview_resampling, str): + resampling = resampling_s2rio(overview_resampling) + else: + resampling = overview_resampling if (blocksize % 16) != 0: warnings.warn("Block size must be a multiple of 16, will be adjusted") @@ -219,7 +223,7 @@ def write_cog( overwrite: bool = False, blocksize: Optional[int] = None, ovr_blocksize: Optional[int] = None, - overview_resampling: Optional[str] = None, + overview_resampling: Optional[Resampling] = None, overview_levels: Optional[List[int]] = None, use_windowed_writes: bool = False, intermediate_compression: Union[bool, str, Dict[str, Any]] = False, diff --git a/docs/about/whats_new.rst b/docs/about/whats_new.rst index 7c06c0ad0..afd67980d 100644 --- a/docs/about/whats_new.rst +++ b/docs/about/whats_new.rst @@ -8,6 +8,9 @@ What's New v1.9.next ========= +- Standardize resampling input supported to `odc.geo.warp.Resampling`. + + v1.9.0-rc3 (27th March 2024) ============================ diff --git a/tests/storage/test_storage_read.py b/tests/storage/test_storage_read.py index 8977c4cb3..a9ccb4d29 100644 --- a/tests/storage/test_storage_read.py +++ b/tests/storage/test_storage_read.py @@ -4,6 +4,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np +import pytest +from rasterio.enums import Resampling + from datacube.storage._read import ( read_time_slice, read_time_slice_v2, @@ -27,6 +30,11 @@ ) +nearest_resampling_parametrize = pytest.mark.parametrize( + "nearest_resampling", ['nearest', Resampling.nearest, Resampling.nearest.value] +) + + def test_pick_read_scale(): assert pick_read_scale(0.7) == 1 assert pick_read_scale(1.3) == 1 @@ -34,7 +42,8 @@ def test_pick_read_scale(): assert pick_read_scale(1.99999) == 2 -def test_read_paste(tmpdir): +@nearest_resampling_parametrize +def test_read_paste(nearest_resampling, tmpdir): from datacube.testutils import mk_test_image from datacube.testutils.io import write_gtiff from pathlib import Path @@ -46,7 +55,7 @@ def test_read_paste(tmpdir): mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None) - def _read(geobox, resampling='nearest', + def _read(geobox, resampling=nearest_resampling, fallback_nodata=-999, dst_nodata=-999, check_paste=False): @@ -112,7 +121,8 @@ def _read(geobox, resampling='nearest', np.testing.assert_array_equal(xx[1::2, 1::2], yy) -def test_read_with_reproject(tmpdir): +@nearest_resampling_parametrize +def test_read_with_reproject(nearest_resampling, tmpdir): from datacube.testutils import mk_test_image from datacube.testutils.io import write_gtiff from pathlib import Path @@ -131,7 +141,7 @@ def test_read_with_reproject(tmpdir): assert mm.geobox == tile def _read(geobox, - resampling='nearest', + resampling=nearest_resampling, fallback_nodata=None, dst_nodata=-999): with RasterFileDataSource(mm.path, 1, nodata=fallback_nodata).open() as rdr: @@ -171,7 +181,8 @@ def _read(geobox, assert nvalid > nempty -def test_read_paste_v2(tmpdir): +@nearest_resampling_parametrize +def test_read_paste_v2(nearest_resampling, tmpdir): from datacube.testutils import mk_test_image from datacube.testutils.io import write_gtiff from datacube.testutils.iodriver import open_reader @@ -184,7 +195,7 @@ def test_read_paste_v2(tmpdir): mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None) - def _read(geobox, resampling='nearest', + def _read(geobox, resampling=nearest_resampling, fallback_nodata=-999, dst_nodata=-999, check_paste=False): @@ -256,7 +267,8 @@ def _read(geobox, resampling='nearest', np.testing.assert_array_equal(xx[1::2, 1::2], yy) -def test_read_with_reproject_v2(tmpdir): +@nearest_resampling_parametrize +def test_read_with_reproject_v2(nearest_resampling, tmpdir): from datacube.testutils import mk_test_image from datacube.testutils.io import write_gtiff from datacube.testutils.iodriver import open_reader @@ -268,7 +280,7 @@ def test_read_with_reproject_v2(tmpdir): assert (xx != -999).all() tile = AlbersGS.tile_geobox((17, -40))[:64, :128] - def _read(geobox, resampling='nearest', + def _read(geobox, resampling=nearest_resampling, fallback_nodata=-999, dst_nodata=-999):