Skip to content

Commit

Permalink
add test for sample_raster
Browse files Browse the repository at this point in the history
  • Loading branch information
kbonney committed Sep 25, 2024
1 parent 2ae78c3 commit 0f7f7ff
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions wntr/tests/test_gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
gpd = None
has_geopandas = False

try:
import rasterio as rio
has_rasterio = True
except ModuleNotFoundError:
rio = None
has_rasterio = False

testdir = dirname(abspath(str(__file__)))
datadir = join(testdir, "networks_for_testing")
ex_datadir = join(testdir, "..", "..", "examples", "networks")
Expand Down Expand Up @@ -70,6 +77,36 @@ def setUpClass(self):
df = pd.DataFrame(point_data)
self.points = gpd.GeoDataFrame(df, crs=None)

# raster testing
points = [
(-120.5, 38.5),
(-120.6, 38.6),
(-120.55, 38.65),
(-120.65, 38.55),
(-120.7, 38.7)
]
point_geometries = [Point(xy) for xy in points]
raster_points = gpd.GeoDataFrame(geometry=point_geometries, crs="EPSG:4326")
raster_points.index = ["A", "B", "C", "D", "E"]
self.raster_points = raster_points

# create example raster
minx, miny, maxx, maxy = raster_points.total_bounds
raster_width = 100
raster_height = 100

x = np.linspace(0, 1, raster_width)
y = np.linspace(0, 1, raster_height)
raster_data = np.cos(y)[:, np.newaxis] * np.sin(x) # arbitrary values

transform = rio.transform.from_bounds(minx, miny, maxx, maxy, raster_width, raster_height)
self.transform = transform

with rio.open(
"test_raster.tif", "w", driver="GTiff", height=raster_height, width=raster_width,
count=1, dtype=raster_data.dtype, crs="EPSG:4326", transform=transform) as dst:
dst.write(raster_data, 1)

@classmethod
def tearDownClass(self):
pass
Expand Down Expand Up @@ -311,5 +348,13 @@ def test_snap_points_to_lines(self):

assert_frame_equal(pd.DataFrame(snapped_points), expected, check_dtype=False)

def test_sample_raster(self):
raster_values = wntr.gis.sample_raster(self.raster_points, "test_raster.tif", 1)

assert (raster_values.index == self.raster_points.index).all()
# self.raster_points.plot(column=values, legend=True)
expected_values = np.array([0.000000, 0.423443, 0.665369, 0.174402, 0.000000])
assert np.isclose(raster_values.values, expected_values, atol=1e-5).all()

if __name__ == "__main__":
unittest.main()

0 comments on commit 0f7f7ff

Please sign in to comment.