From 331dca75632b29f618693fe1da210a125b3f23aa Mon Sep 17 00:00:00 2001 From: Amelie Date: Tue, 4 Jun 2024 11:52:01 +0200 Subject: [PATCH] feat: apply() adapted to dask input --- xdem/coreg/base.py | 35 ++++++++++++++++++++++++++++++++++- xdem/coreg/biascorr.py | 12 ++++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index b2230d3f..212a3d20 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -644,6 +644,11 @@ def _preprocess_coreg_fit( return ref_elev, tba_elev, inlier_mask, transform, crs +def mask_array(arr: NDArrayf, nodata: int | float): + """Convert invalid data to nan.""" + return np.where(np.logical_or(~da.isfinite(arr), arr == nodata), np.nan, arr) + + def _preprocess_coreg_apply( elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame, transform: rio.transform.Affine | None = None, @@ -651,7 +656,7 @@ def _preprocess_coreg_apply( ) -> tuple[NDArrayf | gpd.GeoDataFrame, affine.Affine, rio.crs.CRS]: """Pre-processing and checks of apply for any input.""" - if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame)): + if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame, DataArray)): raise ValueError("Input elevation data should be a raster, an array or a geodataframe.") # If input is geodataframe @@ -660,6 +665,20 @@ def _preprocess_coreg_apply( new_transform = None new_crs = None + # If input is a Dataarray + elif isinstance(elev, DataArray): + new_transform, new_crs = _select_transform_crs( + transform=transform, + crs=crs, + transform_reference=elev.rio.transform(), + transform_other=None, + crs_reference=elev.rio.crs, + crs_other=None, + ) + + # get the masked elev + elev_out = da.map_blocks(mask_array, elev.data, nodata=elev.rio.nodata, chunks=elev.chunks, dtype=elev.dtype) + # If input is a raster or array else: # If input is raster @@ -698,6 +717,14 @@ def _postprocess_coreg_apply_pts( return applied_elev +def _postprocess_coreg_apply_xarray_xarray(applied_elev: da.Array, out_transform: affine.Affine): + """Post-processing and checks of apply for dask inputs.""" + + # TODO mimic what is happening in the postprocess_coreg_apply_rst. + + return applied_elev, out_transform + + def _postprocess_coreg_apply_rst( elev: NDArrayf | gu.Raster, applied_elev: NDArrayf, @@ -783,6 +810,11 @@ def _postprocess_coreg_apply( resample=resample, resampling=resampling, ) + elif isinstance(applied_elev, da.Array): + applied_elev, out_transform = _postprocess_coreg_apply_xarray_xarray( + applied_elev=applied_elev, out_transform=out_transform + ) + else: applied_elev = _postprocess_coreg_apply_pts(applied_elev) @@ -1480,6 +1512,7 @@ def apply( if self._is_affine: warnings.warn("This coregistration method is affine, ignoring `bias_vars` passed to apply().") + # TODO adapt this for dask for var in bias_vars.keys(): bias_vars[var] = gu.raster.get_array_and_mask(bias_vars[var])[0] diff --git a/xdem/coreg/biascorr.py b/xdem/coreg/biascorr.py index 7d7f41a0..9c8c0fec 100644 --- a/xdem/coreg/biascorr.py +++ b/xdem/coreg/biascorr.py @@ -503,7 +503,10 @@ def _apply_rst( # type: ignore statistic=self._meta["bin_statistic"], ) - dem_corr = elev + corr + if isinstance(elev, da.Array): + dem_corr = da.add(elev, corr) + else: + dem_corr = elev + corr return dem_corr, transform @@ -1184,6 +1187,11 @@ def _apply_rst( ) -> tuple[NDArrayf, rio.transform.Affine]: # Define the coordinates for applying the correction - xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0])) + + if type(elev) == da.Array: + xx = da.map_blocks(meshgrid, elev, chunks=elev.chunks, dtype=elev.dtype) + yy = da.map_blocks(meshgrid, elev, axis="y", chunks=elev.chunks, dtype=elev.dtype) + else: + xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0])) return super()._apply_rst(elev=elev, transform=transform, crs=crs, bias_vars={"xx": xx, "yy": yy}, **kwargs)