From 10401760e1b642d62d4c7d70b48c688ec465abf0 Mon Sep 17 00:00:00 2001 From: Ou Ku Date: Wed, 8 Nov 2023 14:42:17 +0100 Subject: [PATCH] add more tests for stm --- tests/test_stm.py | 58 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index 91a6944..e19c77c 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -7,8 +7,11 @@ import xarray as xr from shapely import geometry +from stmtools.stm import _validate_coords + path_multi_polygon = Path(__file__).parent / "./data/multi_polygon.gpkg" + @pytest.fixture def stmat_rd(): # A STM with rd coordinates @@ -45,6 +48,7 @@ def stmat_only_point(): coords=dict(lon=(["space"], da.arange(npoints)), lat=(["space"], da.arange(npoints))), ).unify_chunks() + @pytest.fixture def stmat_wrong_space_label(): npoints = 10 @@ -110,7 +114,7 @@ def test_time_dim_exists(self, stmat_only_point): def test_time_dim_size_one(self, stmat_only_point): stm_reg = stmat_only_point.stm.regulate_dims() assert stm_reg.dims["time"] == 1 - + def test_time_dim_customed_label(self, stmat_wrong_space_label): stm_reg = stmat_wrong_space_label.stm.regulate_dims(space_label="space2") assert stm_reg.dims["time"] == 1 @@ -128,6 +132,30 @@ def test_subset_works_after_regulate_dims(self, stmat_only_point): stm_reg_subset = stm_reg.stm.subset(method="threshold", var="pnt_height", threshold=">5") assert stm_reg_subset.dims["space"] == 4 + def test_validate_coords(self): + stmat_coords = xr.Dataset( + data_vars=dict( + data=( + ["space", "time"], + da.arange(5 * 10).reshape((10, 5)), + ), + x_coor=(["space"], np.arange(10)), + y_coor=(["space"], np.arange(10)), + ), + coords=dict( + x=(["space"], np.arange(10)), + y=(["space"], np.arange(10)), + time=(["time"], np.arange(5)), + ), + ) + + assert _validate_coords(stmat_coords, 'x', 'y') == 1 + assert _validate_coords(stmat_coords, 'x_coor', 'y_coor') == 2 + assert _validate_coords(stmat_coords, 'x', 'y_coor') == 2 + + with pytest.raises(ValueError): + _validate_coords(stmat_coords, 'x_non', 'y_non') + class TestAttributes: def test_numpoints(self, stmat): @@ -136,6 +164,13 @@ def test_numpoints(self, stmat): def test_numepochss(self, stmat): assert stmat.stm.num_epochs == 5 + def test_register_datatype(self, stmat): + stmat_with_dtype = stmat.stm.register_datatype("pnt_height", "pntAttrib") + assert "pnt_height" in stmat_with_dtype.attrs["pntAttrib"] + + def test_register_datatype_nonexists(self, stmat): + with pytest.raises(ValueError): + stmat.stm.register_datatype("non_exist", "pntAttrib") class TestSubset: def test_check_missing_dimension(self, stmat_only_point): @@ -184,7 +219,7 @@ def test_subset_with_polygons_rd(self, stmat_rd, polygon): def test_subset_with_multi_polygons(self, stmat, multi_polygon): stmat_subset = stmat.stm.subset(method="polygon", polygon=multi_polygon) assert stmat_subset.equals(stmat.sel(space=[2, 6])) - + def test_subset_with_multi_polygons_file(self, stmat): stmat_subset = stmat.stm.subset(method="polygon", polygon=path_multi_polygon) assert stmat_subset.equals(stmat.sel(space=[2, 6])) @@ -228,3 +263,22 @@ def test_enrich_multi_fields_multi_polygon(self, stmat, multi_polygon): results = stmat[field].data.compute() results = [res for res in results if res is not None] assert np.all(results == np.array(multi_polygon[field])) + + def test_enrich_multi_fields_multi_polygon_from_file(self, stmat): + multi_polygon = gpd.read_file(path_multi_polygon) + fields = multi_polygon.columns[0:2] + stmat = stmat.stm.enrich_from_polygon(path_multi_polygon, fields) + for field in fields: + assert field in stmat.data_vars + + results = stmat[field].data.compute() + results = [res for res in results if res is not None] + assert np.all(results == np.array(multi_polygon[field])) + + def test_enrich_exetions(self, stmat, multi_polygon): + with pytest.raises(NotImplementedError): + # int not implemented for polygons + stmat = stmat.stm.enrich_from_polygon(999, "field") + + with pytest.raises(ValueError): + stmat = stmat.stm.enrich_from_polygon(multi_polygon, "non_exist_field")