diff --git a/ocsmesh/utils.py b/ocsmesh/utils.py index 3d8cfa21..685fe14a 100644 --- a/ocsmesh/utils.py +++ b/ocsmesh/utils.py @@ -2157,6 +2157,8 @@ def raster_from_numpy( crs=crs, transform=transform, ) as dst: + if isinstance(data, np.ma.MaskedArray): + dst.nodata = data.fill_value dst.write(data, 1) diff --git a/tests/api/utils.py b/tests/api/utils.py index 075343c7..d555b066 100644 --- a/tests/api/utils.py +++ b/tests/api/utils.py @@ -619,6 +619,39 @@ def test_diff_extent_x_n_y(self): # TODO: Test when x and y extent are different pass + + def test_data_masking(self): + fill_value = 12 + in_rast_xy = np.mgrid[0:1:0.2, 0:1:0.2] + in_rast_z_nomask = np.random.random(in_rast_xy[0].shape) + in_rast_z_mask = np.ma.MaskedArray( + in_rast_z_nomask, + mask=np.random.random(size=in_rast_z_nomask.shape) < 0.5, + fill_value=fill_value + ) + + with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + utils.raster_from_numpy( + tf.name, + data=in_rast_z_nomask, + mgrid=in_rast_xy, + crs=4326 + ) + + rast = Raster(tf.name) + self.assertEqual(rast.src.nodata, None) + + with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + utils.raster_from_numpy( + tf.name, + data=in_rast_z_mask, + mgrid=in_rast_xy, + crs=4326 + ) + + rast = Raster(tf.name) + self.assertEqual(rast.src.nodata, fill_value) + if __name__ == '__main__': unittest.main()