Skip to content

Commit

Permalink
Merge pull request #113 from noaa-ocs-modeling/feature/interpolate_ot…
Browse files Browse the repository at this point in the history
…her_bands

Add band argument for interpolation
  • Loading branch information
SorooshMani-NOAA authored Oct 10, 2023
2 parents 0929e5d + b0d9b0c commit d866216
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 12 deletions.
2 changes: 0 additions & 2 deletions ocsmesh/hfun/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,6 @@ def add_region_constraint(
Add fixed-value or fixed-matrix constraint.
add_topo_func_constraint :
Addint constraint based on function of topography
add_courant_num_constraint :
Add constraint based on approximated Courant number
"""

if crs is None:
Expand Down
20 changes: 13 additions & 7 deletions ocsmesh/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ def interpolate(
method: Literal['spline', 'linear', 'nearest'] = 'spline',
nprocs: Optional[int] = None,
info_out_path: Union[pathlib.Path, str, None] = None,
filter_by_shape: bool = False
filter_by_shape: bool = False,
band: int = 1,
) -> None:
"""Interplate values from raster inputs to the mesh nodes.
Expand All @@ -359,8 +360,10 @@ def interpolate(
Number of workers to use when interpolating data.
info_out_path : pathlike or str or None
Path for the output node interpolation information file
filter_by_shape : bool
filter_by_shape : bool, default=False
Flag for node filtering based on raster bbox or shape
band : int, default=1
The band from rasters to use for interpolation
Returns
-------
Expand All @@ -382,15 +385,15 @@ def interpolate(
_mesh_interpolate_worker,
[(self.vert2['coord'], self.crs,
_raster.tmpfile, _raster.chunk_size,
method, filter_by_shape)
method, filter_by_shape, band)
for _raster in raster]
)
pool.join()
else:
res = [_mesh_interpolate_worker(
self.vert2['coord'], self.crs,
_raster.tmpfile, _raster.chunk_size,
method, filter_by_shape)
method, filter_by_shape, band)
for _raster in raster]

values = self.msh_t.value.flatten()
Expand Down Expand Up @@ -2234,7 +2237,8 @@ def _mesh_interpolate_worker(
raster_path: Union[str, Path],
chunk_size: Optional[int],
method: Literal['spline', 'linear', 'nearest'] = "spline",
filter_by_shape: bool = False):
filter_by_shape: bool = False,
band: int = 1):
"""Interpolator worker function to be used in parallel calls
Parameters
Expand All @@ -2249,8 +2253,10 @@ def _mesh_interpolate_worker(
Chunk size for windowing over the raster.
method : {'spline', 'linear', 'nearest'}, default='spline'
Method of interpolation.
filter_by_shape : bool
filter_by_shape : bool, default=False
Flag for node filtering based on raster bbox or shape
band : int, default=1
The band from rasters to use for interpolation
Returns
-------
Expand Down Expand Up @@ -2281,7 +2287,7 @@ def _mesh_interpolate_worker(
xi = raster.get_x(window)
yi = raster.get_y(window)
# Use masked array to ignore missing values from DEM
zi = raster.get_values(window=window, masked=True)
zi = raster.get_values(window=window, masked=True, band=band)

if not filter_by_shape:
_idxs = np.logical_and(
Expand Down
15 changes: 13 additions & 2 deletions ocsmesh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,20 +2080,31 @@ def raster_from_numpy(
if not isinstance(crs, CRS):
crs = CRS.from_user_input(crs)

nbands = 1
if data.ndim == 3:
nbands = data.shape[2]
elif data.ndim != 2:
raise ValueError("Invalid data dimensions!")

with rio.open(
filename,
'w',
driver='GTiff',
height=data.shape[0],
width=data.shape[1],
count=1,
count=nbands,
dtype=data.dtype,
crs=crs,
transform=transform,
) as dst:
if isinstance(data, np.ma.MaskedArray):
dst.nodata = data.fill_value
dst.write(data, 1)

data = data.reshape(data.shape[0], data.shape[1], -1)
for i in range(nbands):
dst.write(data.take(i, axis=2), i + 1)




def msht_from_numpy(
Expand Down
65 changes: 64 additions & 1 deletion tests/api/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import tempfile
import unittest
import warnings
import shutil
from pathlib import Path

import numpy as np
from jigsawpy import jigsaw_msh_t
from pyproj import CRS
from shapely import geometry

from ocsmesh import utils
from ocsmesh.mesh.mesh import Mesh
from ocsmesh.mesh.mesh import Mesh, Raster



Expand Down Expand Up @@ -317,5 +319,66 @@ def test_specify_boundary_on_mesh_with_no_boundary(self):
self.assertEqual(bdry.open().iloc[0]['index_id'], [1, 2, 3])


class RasterInterpolation(unittest.TestCase):

def setUp(self):
self.tdir = Path(tempfile.mkdtemp())

msht1 = utils.create_rectangle_mesh(
nx=13, ny=5, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85),
holes=[],
)
msht1.crs = CRS.from_user_input(4326)
msht2 = utils.create_rectangle_mesh(
nx=11, ny=7, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85),
holes=[],
)
msht2.crs = CRS.from_user_input(4326)
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', category=UserWarning,
message='Input mesh has no CRS information'
)
self.mesh1 = Mesh(msht1)
self.mesh2 = Mesh(msht2)

self.rast = self.tdir / 'rast.tif'

rast_xy = np.mgrid[-74:-71:0.1, 40.9:40.5:-0.01]
rast_z = np.ones((rast_xy.shape[1], rast_xy.shape[2], 2))
rast_z[:, :, 1] = 2
utils.raster_from_numpy(
self.rast, rast_z, rast_xy, 4326
)


def tearDown(self):
shutil.rmtree(self.tdir)


def test_interpolation_io(self):
rast = Raster(self.rast)

self.mesh1.interpolate(rast)
self.assertTrue(np.isclose(self.mesh1.value, 1).all())

# TODO: Improve the assertion!
with self.assertRaises(Exception):
self.mesh1.interpolate(self.mesh2)


def test_interpolation_band(self):
rast = Raster(self.rast)

self.mesh1.interpolate(rast)
self.assertTrue(np.isclose(self.mesh1.value, 1).all())

self.mesh1.interpolate(rast, band=2)
self.assertTrue(np.isclose(self.mesh1.value, 2).all())


# TODO Add more interpolation tests


if __name__ == '__main__':
unittest.main()
39 changes: 39 additions & 0 deletions tests/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,45 @@ def test_data_masking(self):
self.assertEqual(rast.src.nodata, fill_value)


def test_multiband_raster_data(self):
nbands = 5
in_data = np.ones((3, 4, nbands))
for i in range(nbands):
in_data[:, :, i] *= i
in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1]
with tempfile.NamedTemporaryFile(suffix='.tiff') as tf:
utils.raster_from_numpy(
tf.name,
data=in_data,
mgrid=in_rast_xy,
crs=4326
)
rast = Raster(tf.name)
self.assertEqual(rast.count, nbands)
for i in range(nbands):
with self.subTest(band_number=i):
self.assertTrue(
(rast.get_values(band=i+1) == i).all()
)


def test_multiband_raster_invalid_io(self):
in_data = np.ones((3, 4, 5, 6))
in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1]
with tempfile.NamedTemporaryFile(suffix='.tiff') as tf:
with self.assertRaises(ValueError) as cm:
utils.raster_from_numpy(
tf.name,
data=in_data,
mgrid=in_rast_xy,
crs=4326
)
exc = cm.exception
self.assertRegex(str(exc).lower(), '.*dimension.*')




class ShapeToMeshT(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit d866216

Please sign in to comment.