Skip to content

Commit

Permalink
ENH: consistent raster.clip_mask and clip_geom methods (#290)
Browse files Browse the repository at this point in the history
* Made arguments consistent between raster.clip_geom and raster.clip_mask and updated tests (#284)

* Added black formatting
  • Loading branch information
Tjalling-dejong authored Apr 6, 2023
1 parent 4266d3f commit 4fff76b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 9 additions & 5 deletions hydromt/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,22 +1047,26 @@ def clip_bbox(self, bbox, align=None, buffer=0, crs=None):
return self._obj.sel({self.x_dim: slice(x0, x1), self.y_dim: slice(y0, y1)})

# TODO make consistent with clip_geom
def clip_mask(self, mask):
def clip_mask(self, da_mask: xr.DataArray, mask: bool = False):
"""Clip object to region with mask values greater than zero.
Arguments
---------
mask : xarray.DataArray
da_mask : xarray.DataArray
Mask array.
mask: bool, optional
Mask values outside geometry with the raster nodata value
Returns
-------
xarray.DataSet or DataArray
Data clipped to mask
"""
if not isinstance(mask, xr.DataArray):
if not isinstance(da_mask, xr.DataArray):
raise ValueError("Mask should be xarray.DataArray type.")
if not mask.raster.shape == self.shape:
if not da_mask.raster.shape == self.shape:
raise ValueError("Mask shape invalid.")
mask_bin = (mask.values != 0).astype(np.uint8)
if mask:
return self._obj.where(da_mask)
mask_bin = (da_mask.values != 0).astype(np.uint8)
if not np.any(mask_bin):
raise ValueError("Invalid mask.")
row_slice, col_slice = ndimage.find_objects(mask_bin)[0]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def test_clip(transform, shape):
# test mask
da_clip1 = da.raster.clip_mask(da.raster.geometry_mask(gdf))
assert np.all(np.isclose(da_clip1.raster.bounds, da_clip0.raster.bounds))
da_clip1 = da.raster.clip_mask(da.raster.geometry_mask(gdf), mask=True)
assert np.all(np.isclose(da_clip1.raster.bounds, da.raster.bounds))
# test geom - different crs
da_clip1 = da.raster.clip_geom(gdf.to_crs(3857))
assert np.all(np.isclose(da_clip1.raster.bounds, da_clip0.raster.bounds))
Expand Down

0 comments on commit 4fff76b

Please sign in to comment.