From 0f7f7ff9b3bb21f0045170b17100a6d2861b5c82 Mon Sep 17 00:00:00 2001 From: kbonney Date: Wed, 25 Sep 2024 09:53:48 -0400 Subject: [PATCH] add test for sample_raster --- wntr/tests/test_gis.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/wntr/tests/test_gis.py b/wntr/tests/test_gis.py index 31513f26..12d3f39b 100644 --- a/wntr/tests/test_gis.py +++ b/wntr/tests/test_gis.py @@ -23,6 +23,13 @@ gpd = None has_geopandas = False +try: + import rasterio as rio + has_rasterio = True +except ModuleNotFoundError: + rio = None + has_rasterio = False + testdir = dirname(abspath(str(__file__))) datadir = join(testdir, "networks_for_testing") ex_datadir = join(testdir, "..", "..", "examples", "networks") @@ -70,6 +77,36 @@ def setUpClass(self): df = pd.DataFrame(point_data) self.points = gpd.GeoDataFrame(df, crs=None) + # raster testing + points = [ + (-120.5, 38.5), + (-120.6, 38.6), + (-120.55, 38.65), + (-120.65, 38.55), + (-120.7, 38.7) + ] + point_geometries = [Point(xy) for xy in points] + raster_points = gpd.GeoDataFrame(geometry=point_geometries, crs="EPSG:4326") + raster_points.index = ["A", "B", "C", "D", "E"] + self.raster_points = raster_points + + # create example raster + minx, miny, maxx, maxy = raster_points.total_bounds + raster_width = 100 + raster_height = 100 + + x = np.linspace(0, 1, raster_width) + y = np.linspace(0, 1, raster_height) + raster_data = np.cos(y)[:, np.newaxis] * np.sin(x) # arbitrary values + + transform = rio.transform.from_bounds(minx, miny, maxx, maxy, raster_width, raster_height) + self.transform = transform + + with rio.open( + "test_raster.tif", "w", driver="GTiff", height=raster_height, width=raster_width, + count=1, dtype=raster_data.dtype, crs="EPSG:4326", transform=transform) as dst: + dst.write(raster_data, 1) + @classmethod def tearDownClass(self): pass @@ -311,5 +348,13 @@ def test_snap_points_to_lines(self): assert_frame_equal(pd.DataFrame(snapped_points), expected, check_dtype=False) + def test_sample_raster(self): + raster_values = wntr.gis.sample_raster(self.raster_points, "test_raster.tif", 1) + + assert (raster_values.index == self.raster_points.index).all() + # self.raster_points.plot(column=values, legend=True) + expected_values = np.array([0.000000, 0.423443, 0.665369, 0.174402, 0.000000]) + assert np.isclose(raster_values.values, expected_values, atol=1e-5).all() + if __name__ == "__main__": unittest.main() \ No newline at end of file