Skip to content

Commit

Permalink
feat: apply() adapted to dask input
Browse files Browse the repository at this point in the history
  • Loading branch information
ameliefroessl committed Jun 4, 2024
1 parent fc30046 commit 331dca7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
35 changes: 34 additions & 1 deletion xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,14 +644,19 @@ def _preprocess_coreg_fit(
return ref_elev, tba_elev, inlier_mask, transform, crs


def mask_array(arr: NDArrayf, nodata: int | float):
"""Convert invalid data to nan."""
return np.where(np.logical_or(~da.isfinite(arr), arr == nodata), np.nan, arr)


def _preprocess_coreg_apply(
elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
transform: rio.transform.Affine | None = None,
crs: rio.crs.CRS | None = None,
) -> tuple[NDArrayf | gpd.GeoDataFrame, affine.Affine, rio.crs.CRS]:
"""Pre-processing and checks of apply for any input."""

if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame)):
if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame, DataArray)):
raise ValueError("Input elevation data should be a raster, an array or a geodataframe.")

# If input is geodataframe
Expand All @@ -660,6 +665,20 @@ def _preprocess_coreg_apply(
new_transform = None
new_crs = None

# If input is a Dataarray
elif isinstance(elev, DataArray):
new_transform, new_crs = _select_transform_crs(
transform=transform,
crs=crs,
transform_reference=elev.rio.transform(),
transform_other=None,
crs_reference=elev.rio.crs,
crs_other=None,
)

# get the masked elev
elev_out = da.map_blocks(mask_array, elev.data, nodata=elev.rio.nodata, chunks=elev.chunks, dtype=elev.dtype)

# If input is a raster or array
else:
# If input is raster
Expand Down Expand Up @@ -698,6 +717,14 @@ def _postprocess_coreg_apply_pts(
return applied_elev


def _postprocess_coreg_apply_xarray_xarray(applied_elev: da.Array, out_transform: affine.Affine):
"""Post-processing and checks of apply for dask inputs."""

# TODO mimic what is happening in the postprocess_coreg_apply_rst.

return applied_elev, out_transform


def _postprocess_coreg_apply_rst(
elev: NDArrayf | gu.Raster,
applied_elev: NDArrayf,
Expand Down Expand Up @@ -783,6 +810,11 @@ def _postprocess_coreg_apply(
resample=resample,
resampling=resampling,
)
elif isinstance(applied_elev, da.Array):
applied_elev, out_transform = _postprocess_coreg_apply_xarray_xarray(
applied_elev=applied_elev, out_transform=out_transform
)

else:
applied_elev = _postprocess_coreg_apply_pts(applied_elev)

Expand Down Expand Up @@ -1480,6 +1512,7 @@ def apply(
if self._is_affine:
warnings.warn("This coregistration method is affine, ignoring `bias_vars` passed to apply().")

# TODO adapt this for dask
for var in bias_vars.keys():
bias_vars[var] = gu.raster.get_array_and_mask(bias_vars[var])[0]

Expand Down
12 changes: 10 additions & 2 deletions xdem/coreg/biascorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,10 @@ def _apply_rst( # type: ignore
statistic=self._meta["bin_statistic"],
)

dem_corr = elev + corr
if isinstance(elev, da.Array):
dem_corr = da.add(elev, corr)
else:
dem_corr = elev + corr

return dem_corr, transform

Expand Down Expand Up @@ -1184,6 +1187,11 @@ def _apply_rst(
) -> tuple[NDArrayf, rio.transform.Affine]:

# Define the coordinates for applying the correction
xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0]))

if type(elev) == da.Array:
xx = da.map_blocks(meshgrid, elev, chunks=elev.chunks, dtype=elev.dtype)
yy = da.map_blocks(meshgrid, elev, axis="y", chunks=elev.chunks, dtype=elev.dtype)
else:
xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0]))

return super()._apply_rst(elev=elev, transform=transform, crs=crs, bias_vars={"xx": xx, "yy": yy}, **kwargs)

0 comments on commit 331dca7

Please sign in to comment.