From 3089a651e1ab8a2187feeff31c53bcaa9026ea52 Mon Sep 17 00:00:00 2001 From: DirkEilander Date: Thu, 13 Apr 2023 16:35:08 +0200 Subject: [PATCH] Revert "ENH: consistent raster.clip_mask and clip_geom methods (#290)" This reverts commit 4fff76b29c00e06efd9e96b4ecb76019a43ec525. --- hydromt/raster.py | 14 +++++--------- tests/test_raster.py | 2 -- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/hydromt/raster.py b/hydromt/raster.py index 254ac160d..3115bb631 100644 --- a/hydromt/raster.py +++ b/hydromt/raster.py @@ -1047,26 +1047,22 @@ def clip_bbox(self, bbox, align=None, buffer=0, crs=None): return self._obj.sel({self.x_dim: slice(x0, x1), self.y_dim: slice(y0, y1)}) # TODO make consistent with clip_geom - def clip_mask(self, da_mask: xr.DataArray, mask: bool = False): + def clip_mask(self, mask): """Clip object to region with mask values greater than zero. Arguments --------- - da_mask : xarray.DataArray + mask : xarray.DataArray Mask array. - mask: bool, optional - Mask values outside geometry with the raster nodata value Returns ------- xarray.DataSet or DataArray Data clipped to mask """ - if not isinstance(da_mask, xr.DataArray): + if not isinstance(mask, xr.DataArray): raise ValueError("Mask should be xarray.DataArray type.") - if not da_mask.raster.shape == self.shape: + if not mask.raster.shape == self.shape: raise ValueError("Mask shape invalid.") - if mask: - return self._obj.where(da_mask) - mask_bin = (da_mask.values != 0).astype(np.uint8) + mask_bin = (mask.values != 0).astype(np.uint8) if not np.any(mask_bin): raise ValueError("Invalid mask.") row_slice, col_slice = ndimage.find_objects(mask_bin)[0] diff --git a/tests/test_raster.py b/tests/test_raster.py index c080787a2..7aa067c98 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -237,8 +237,6 @@ def test_clip(transform, shape): # test mask da_clip1 = da.raster.clip_mask(da.raster.geometry_mask(gdf)) assert np.all(np.isclose(da_clip1.raster.bounds, da_clip0.raster.bounds)) - da_clip1 = da.raster.clip_mask(da.raster.geometry_mask(gdf), mask=True) - assert np.all(np.isclose(da_clip1.raster.bounds, da.raster.bounds)) # test geom - different crs da_clip1 = da.raster.clip_geom(gdf.to_crs(3857)) assert np.all(np.isclose(da_clip1.raster.bounds, da_clip0.raster.bounds))