Skip to content

Commit

Permalink
Add apply_z_correction Parameter to warp_dem Function (#670)
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn authored Dec 16, 2024
1 parent 55bb902 commit 524311d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
31 changes: 31 additions & 0 deletions tests/test_coreg/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,37 @@ def test_warp_dem() -> None:
# Due to the randomness, the threshold is quite high, but would be something like 10+ if it was incorrect.
assert spatialstats.nmad(dem - untransformed_dem) < 0.5

# Test with Z-correction disabled
transformed_dem_no_z = coreg.base.warp_dem(
dem=dem,
transform=transform,
source_coords=source_coords,
destination_coords=dest_coords,
resampling="linear",
apply_z_correction=False,
)

# Try to undo the warp by reversing the source-destination coordinates with Z-correction disabled
untransformed_dem_no_z = coreg.base.warp_dem(
dem=transformed_dem_no_z,
transform=transform,
source_coords=dest_coords,
destination_coords=source_coords,
resampling="linear",
apply_z_correction=False,
)

# Validate that the DEM is now more or less the same as the original, with Z-correction disabled.
# The result should be similar to the original, but with no Z-shift applied.
assert spatialstats.nmad(dem - untransformed_dem_no_z) < 0.5

# The difference between the two DEMs should be the vertical shift.
# We expect the difference to be approximately equal to the average vertical shift.
expected_vshift = np.mean(dest_coords[:, 2] - source_coords[:, 2])

# Check that the mean difference between the DEMs matches the expected vertical shift.
assert np.nanmean(transformed_dem_no_z - transformed_dem) == pytest.approx(expected_vshift, rel=0.3)

if False:
import matplotlib.pyplot as plt

Expand Down
10 changes: 8 additions & 2 deletions xdem/coreg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,7 @@ def __init__(
success_threshold: float = 0.8,
n_threads: int | None = None,
warn_failures: bool = False,
apply_z_correction: bool = True,
) -> None:
"""
Instantiate a blockwise processing object.
Expand All @@ -3022,6 +3023,7 @@ def __init__(
:param success_threshold: Raise an error if fewer chunks than the fraction failed for any reason.
:param n_threads: The maximum amount of threads to use. Default=auto
:param warn_failures: Trigger or ignore warnings for each exception/warning in each block.
:param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True).
"""
if isinstance(step, type):
raise ValueError(
Expand All @@ -3032,6 +3034,7 @@ def __init__(
self.success_threshold = success_threshold
self.n_threads = n_threads
self.warn_failures = warn_failures
self.apply_z_correction = apply_z_correction

super().__init__()

Expand Down Expand Up @@ -3396,6 +3399,7 @@ def _apply_rst(
source_coords=all_points[:, :, 1],
destination_coords=all_points[:, :, 0],
resampling="linear",
apply_z_correction=self.apply_z_correction,
)

return warped_dem, transform
Expand Down Expand Up @@ -3436,6 +3440,7 @@ def warp_dem(
resampling: str = "cubic",
trim_border: bool = True,
dilate_mask: bool = True,
apply_z_correction: bool = True,
) -> NDArrayf:
"""
(22/08/24: Method currently used only for blockwise coregistration)
Expand All @@ -3448,6 +3453,7 @@ def warp_dem(
:param resampling: The resampling order to use. Choices: ['nearest', 'linear', 'cubic'].
:param trim_border: Remove values outside of the interpolation regime (True) or leave them unmodified (False).
:param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong.
:param apply_z_correction: Boolean to toggle whether the Z-offset correction is applied or not (default True).
:raises ValueError: If the inputs are poorly formatted.
:raises AssertionError: For unexpected outputs.
Expand Down Expand Up @@ -3534,8 +3540,8 @@ def warp_dem(

warped[new_mask] = np.nan

# If the coordinates are 3D (N, 3), apply a Z correction as well.
if not no_vertical:
# Apply the Z-correction if apply_z_correction is True and if the coordinates are 3D (N, 3)
if not no_vertical and apply_z_correction:
grid_offsets = scipy.interpolate.griddata(
points=destination_coords_scaled[:, :2],
values=source_coords_scaled[:, 2] - destination_coords_scaled[:, 2],
Expand Down

0 comments on commit 524311d

Please sign in to comment.