Skip to content

Commit

Permalink
add more tests for stm
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerkuou committed Nov 8, 2023
1 parent cd071a1 commit 1040176
Showing 1 changed file with 56 additions and 2 deletions.
58 changes: 56 additions & 2 deletions tests/test_stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import xarray as xr
from shapely import geometry

from stmtools.stm import _validate_coords

path_multi_polygon = Path(__file__).parent / "./data/multi_polygon.gpkg"


@pytest.fixture
def stmat_rd():
# A STM with rd coordinates
Expand Down Expand Up @@ -45,6 +48,7 @@ def stmat_only_point():
coords=dict(lon=(["space"], da.arange(npoints)), lat=(["space"], da.arange(npoints))),
).unify_chunks()


@pytest.fixture
def stmat_wrong_space_label():
npoints = 10
Expand Down Expand Up @@ -110,7 +114,7 @@ def test_time_dim_exists(self, stmat_only_point):
def test_time_dim_size_one(self, stmat_only_point):
stm_reg = stmat_only_point.stm.regulate_dims()
assert stm_reg.dims["time"] == 1

def test_time_dim_customed_label(self, stmat_wrong_space_label):
stm_reg = stmat_wrong_space_label.stm.regulate_dims(space_label="space2")
assert stm_reg.dims["time"] == 1
Expand All @@ -128,6 +132,30 @@ def test_subset_works_after_regulate_dims(self, stmat_only_point):
stm_reg_subset = stm_reg.stm.subset(method="threshold", var="pnt_height", threshold=">5")
assert stm_reg_subset.dims["space"] == 4

def test_validate_coords(self):
stmat_coords = xr.Dataset(
data_vars=dict(
data=(
["space", "time"],
da.arange(5 * 10).reshape((10, 5)),
),
x_coor=(["space"], np.arange(10)),
y_coor=(["space"], np.arange(10)),
),
coords=dict(
x=(["space"], np.arange(10)),
y=(["space"], np.arange(10)),
time=(["time"], np.arange(5)),
),
)

assert _validate_coords(stmat_coords, 'x', 'y') == 1
assert _validate_coords(stmat_coords, 'x_coor', 'y_coor') == 2
assert _validate_coords(stmat_coords, 'x', 'y_coor') == 2

with pytest.raises(ValueError):
_validate_coords(stmat_coords, 'x_non', 'y_non')


class TestAttributes:
def test_numpoints(self, stmat):
Expand All @@ -136,6 +164,13 @@ def test_numpoints(self, stmat):
def test_numepochss(self, stmat):
assert stmat.stm.num_epochs == 5

def test_register_datatype(self, stmat):
stmat_with_dtype = stmat.stm.register_datatype("pnt_height", "pntAttrib")
assert "pnt_height" in stmat_with_dtype.attrs["pntAttrib"]

def test_register_datatype_nonexists(self, stmat):
with pytest.raises(ValueError):
stmat.stm.register_datatype("non_exist", "pntAttrib")

class TestSubset:
def test_check_missing_dimension(self, stmat_only_point):
Expand Down Expand Up @@ -184,7 +219,7 @@ def test_subset_with_polygons_rd(self, stmat_rd, polygon):
def test_subset_with_multi_polygons(self, stmat, multi_polygon):
stmat_subset = stmat.stm.subset(method="polygon", polygon=multi_polygon)
assert stmat_subset.equals(stmat.sel(space=[2, 6]))

def test_subset_with_multi_polygons_file(self, stmat):
stmat_subset = stmat.stm.subset(method="polygon", polygon=path_multi_polygon)
assert stmat_subset.equals(stmat.sel(space=[2, 6]))
Expand Down Expand Up @@ -228,3 +263,22 @@ def test_enrich_multi_fields_multi_polygon(self, stmat, multi_polygon):
results = stmat[field].data.compute()
results = [res for res in results if res is not None]
assert np.all(results == np.array(multi_polygon[field]))

def test_enrich_multi_fields_multi_polygon_from_file(self, stmat):
multi_polygon = gpd.read_file(path_multi_polygon)
fields = multi_polygon.columns[0:2]
stmat = stmat.stm.enrich_from_polygon(path_multi_polygon, fields)
for field in fields:
assert field in stmat.data_vars

results = stmat[field].data.compute()
results = [res for res in results if res is not None]
assert np.all(results == np.array(multi_polygon[field]))

def test_enrich_exetions(self, stmat, multi_polygon):
with pytest.raises(NotImplementedError):
# int not implemented for polygons
stmat = stmat.stm.enrich_from_polygon(999, "field")

with pytest.raises(ValueError):
stmat = stmat.stm.enrich_from_polygon(multi_polygon, "non_exist_field")

0 comments on commit 1040176

Please sign in to comment.