From 4efa368741da6bda496e443aada0f8c4d817a918 Mon Sep 17 00:00:00 2001
From: Martin Vonk
Date: Wed, 18 Sep 2024 10:15:41 +0200
Subject: [PATCH] add some more tests to up coverage

---
Review note: the hunks below were edited by hand during review (one-line
docstrings, pytest.raises instead of try/except/else with "assert False",
"-> None" annotations), so the post-image blob ids on the "index" lines are
stale.

 tests/test_si.py       | 82 +++++++++++++++++++++++++++++++++++++++++-
 tests/test_validate.py | 41 +++++++++++++++++++--
 2 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/tests/test_si.py b/tests/test_si.py
index 6e75ab6..c7b5fac 100644
--- a/tests/test_si.py
+++ b/tests/test_si.py
@@ -1,6 +1,8 @@
-from pandas import Series, Timestamp
+import pytest
+from pandas import DataFrame, Series, Timestamp
 from scipy.stats import norm
 from spei import SI, sgi, spei, spi, ssfi
+from spei.dist import Dist
 
 
 def test_spi(prec: Series) -> None:
@@ -32,3 +34,81 @@ def test_SI(prec: Series) -> None:
     si.pdf()
     dist = si.get_dist(Timestamp("2010-01-01"))
     dist.ks_test()
+
+
+def test_SI_post_init_timescale(prec: Series) -> None:
+    """The stored series should be the 30-step rolling sum of the input."""
+    si = SI(prec, dist=norm, timescale=30, fit_freq="ME")
+    assert si.series.equals(
+        prec.rolling(30, min_periods=30).sum().dropna()
+    ), "Timescale rolling sum not applied correctly"
+
+
+def test_SI_post_init_fit_freq_infer(prec: Series) -> None:
+    """fit_freq should be set even when it is not passed explicitly."""
+    si = SI(prec, dist=norm, timescale=0)
+    assert si.fit_freq is not None, "Frequency inference failed"
+
+
+def test_SI_post_init_grouped_year(prec: Series) -> None:
+    """_grouped_year should be a DataFrame after construction."""
+    si = SI(prec, dist=norm, timescale=0, fit_freq="ME")
+    assert isinstance(si._grouped_year, DataFrame), "Grouped year DataFrame not created"
+
+
+def test_SI_post_init_fit_window_adjustment(prec: Series) -> None:
+    """An even fit_window of 2 should be stored as 3."""
+    si = SI(prec, dist=norm, timescale=0, fit_freq="D", fit_window=2)
+    assert si.fit_window == 3, "Fit window not adjusted to odd number"
+
+
+def test_SI_post_init_fit_window_minimum(prec: Series) -> None:
+    """A fit_window of 1 should be stored as 3."""
+    si = SI(prec, dist=norm, timescale=0, fit_freq="D", fit_window=1)
+    assert si.fit_window == 3, "Fit window not adjusted to minimum value"
+
+
+def test_fit_distribution_normal_scores_transform(prec: Series) -> None:
+    """With normal_scores_transform=True, fitting should leave _dist_dict empty."""
+    si = SI(prec, dist=norm, timescale=30, fit_freq="ME", normal_scores_transform=True)
+    si.fit_distribution()
+    assert (
+        not si._dist_dict
+    ), "Distribution dictionary should be empty when using normal scores transform"
+
+
+def test_fit_distribution_with_fit_window(prec: Series) -> None:
+    """Fitting with a fit_window should fill _dist_dict with Dist objects."""
+    si = SI(prec, dist=norm, timescale=30, fit_freq="D", fit_window=5)
+    si.fit_distribution()
+    assert (
+        si._dist_dict
+    ), "Distribution dictionary should not be empty when using fit window"
+    for dist in si._dist_dict.values():
+        assert isinstance(
+            dist, Dist
+        ), "Items in distribution dictionary should be of type Dist"
+
+
+def test_fit_distribution_with_fit_freq(prec: Series) -> None:
+    """Fitting with only a fit_freq should fill _dist_dict with Dist objects."""
+    si = SI(prec, dist=norm, timescale=30, fit_freq="M")
+    si.fit_distribution()
+    assert (
+        si._dist_dict
+    ), "Distribution dictionary should not be empty when using fit frequency"
+    for dist in si._dist_dict.values():
+        assert isinstance(
+            dist, Dist
+        ), "Items in distribution dictionary should be of type Dist"
+
+
+def test_fit_distribution_invalid_fit_freq_with_window(prec: Series) -> None:
+    """Combining fit_freq="M" with a fit_window should raise a ValueError."""
+    si = SI(prec, dist=norm, timescale=30, fit_freq="M", fit_window=5)
+    with pytest.raises(ValueError) as excinfo:
+        si.fit_distribution()
+    assert (
+        str(excinfo.value)
+        == "Frequency fit_freq must be 'D' or 'W', not 'M', if a fit_window is provided."
+    )
diff --git a/tests/test_validate.py b/tests/test_validate.py
index f459336..4a49482 100644
--- a/tests/test_validate.py
+++ b/tests/test_validate.py
@@ -1,8 +1,8 @@
 import logging
 
 import pytest
-from pandas import DataFrame, DatetimeIndex, Series, Timestamp, to_datetime
-from spei.utils import validate_index, validate_series
+from pandas import DataFrame, DatetimeIndex, Index, Series, Timestamp, to_datetime
+from spei.utils import infer_frequency, validate_index, validate_series
 
 
 def test_validate_index(caplog) -> None:
@@ -50,3 +50,40 @@ def test_validate_series_df_2d() -> None:
     with pytest.raises(TypeError):
         df = DataFrame({"s1": [1, 2, 3], "s2": [1, 2, 3]}, index=to_datetime([1, 2, 3]))
         validate_series(df)
+
+
+def test_infer_frequency_monthly() -> None:
+    """Month-start dates are expected to yield "M"."""
+    index = DatetimeIndex(["2020-01-01", "2020-02-01", "2020-03-01"])
+    assert infer_frequency(index) == "M"
+
+
+def test_infer_frequency_weekly() -> None:
+    """Dates spaced seven days apart are expected to yield "W"."""
+    index = DatetimeIndex(["2020-01-01", "2020-01-08", "2020-01-15"])
+    assert infer_frequency(index) == "W"
+
+
+def test_infer_frequency_daily() -> None:
+    """Consecutive days are expected to yield "D"."""
+    index = DatetimeIndex(["2020-01-01", "2020-01-02", "2020-01-03"])
+    assert infer_frequency(index) == "D"
+
+
+def test_infer_frequency_no_infer() -> None:
+    """Irregularly spaced dates are expected to yield "ME"."""
+    index = DatetimeIndex(["2020-01-01", "2020-01-03", "2020-01-07"])
+    assert infer_frequency(index) == "ME"  # Assuming pandas version >= 2.2.0
+
+
+def test_infer_frequency_non_datetime_index() -> None:
+    """A plain Index of month-start date strings is also expected to yield "M"."""
+    index = Index(["2020-01-01", "2020-02-01", "2020-03-01"])
+    assert infer_frequency(index) == "M"
+
+
+def test_infer_frequency_invalid_index() -> None:
+    """A plain Index of non-date strings should raise a ValueError."""
+    index = Index(["a", "b", "c"])
+    with pytest.raises(ValueError):
+        infer_frequency(index)