diff --git a/tests/conftest.py b/tests/conftest.py index b3ef78e..eaf9895 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ def path_m3924_excel(): def path_example_prepped_csv(): return os.path.join(TEST_DIR, "files/example_prepped.csv") + @pytest.fixture def path_m3924_prepped_csv(): return os.path.join(TEST_DIR, "files/test-m3924.csv") @@ -37,10 +38,12 @@ def path_m3924_prepped_csv(): def path_example_forecast_csv(): return os.path.join(TEST_DIR, "files/example_forecast.csv") + @pytest.fixture def path_forecast_csv_m3924(): return os.path.join(TEST_DIR, "files/test-m3924-forecast.csv") + @pytest.fixture def path_example_img(): return os.path.join(TEST_DIR, "tmp/path_example_img.png") @@ -51,11 +54,13 @@ def example_forecast_df(path_example_forecast_csv): df = load_file(path_example_forecast_csv, file_type="csv") return df + @pytest.fixture def m3924_forecast_df(path_forecast_csv_m3924): df = load_file(path_forecast_csv_m3924, file_type="csv") return df + @pytest.fixture def example_forecast_df_baseline(path_example_forecast_csv): baseline_days = 7.0 diff --git a/tests/test_forecast.py b/tests/test_forecast.py index 4a7af22..13089b4 100644 --- a/tests/test_forecast.py +++ b/tests/test_forecast.py @@ -52,6 +52,7 @@ def test_arima_cli(capsys, monkeypatch): print(captured) assert os.path.exists(output_path), "Output file was not created" + def test_arima_m3924(path_m3924_prepped_csv): model_kwargs = { "order": (0, 0, 10), diff --git a/tests/test_stats.py b/tests/test_stats.py index 14d5426..a21a898 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -2,9 +2,12 @@ import polars as pl from vaxstats.analysis.forecast import detect_fever_hypothermia, run_analysis -from vaxstats.analysis.hourly import add_hourly_thresholds, calculate_hourly_stats from vaxstats.analysis.residual import add_residuals_col, get_residual_bounds from vaxstats.analysis.stats import get_column_stat, get_column_stats +from vaxstats.analysis.timeframe import ( + add_hourly_thresholds, + calculate_stats_by_timeframe, +) from vaxstats.utils import get_baseline_df, str_to_datetime @@ -43,13 +46,15 @@ def test_residual_bounds(example_forecast_df_baseline): ) -def test_calculate_hourly_stats(example_forecast_df_baseline): +def test_calculate_stats_by_hour(example_forecast_df_baseline): df = example_forecast_df_baseline - hourly_stats = calculate_hourly_stats( + hourly_stats = calculate_stats_by_timeframe( df, + timeframe="hour", data_column="y", pred_column="y_hat", date_column="ds", + start_from_first=False, ) assert "y_median" in hourly_stats.columns @@ -75,6 +80,37 @@ def test_calculate_hourly_stats(example_forecast_df_baseline): assert np.allclose(temps_hourly_median[-1], np.array([38.8151678])) +def test_calculate_stats_by_day(example_forecast_df_baseline): + df = example_forecast_df_baseline + stats = calculate_stats_by_timeframe( + df, + timeframe="day", + data_column="y", + pred_column="y_hat", + date_column="ds", + start_from_first=True, + ) + + assert "y_median" in stats.columns + assert "y_hat_median" in stats.columns + assert "start_time" in stats.columns + assert "end_time" in stats.columns + assert "data_points" in stats.columns + + # Check if start_time is always less than or equal to end_time + assert (stats["start_time"] <= stats["end_time"]).all() + + # Check if data_points is always positive + assert (stats["data_points"] > 0).all() + + temps_median = stats["y_median"].to_numpy() + assert np.allclose( + temps_median[:4], + np.array([37.635, 37.707, 37.62, 37.501]), + ) + assert np.allclose(temps_median[-1], np.array([37.2575])) + + def test_calculate_thresholds(example_forecast_df): df = example_forecast_df df = add_residuals_col(df) @@ -86,8 +122,8 @@ def test_calculate_thresholds(example_forecast_df): df_baseline = get_baseline_df(df, baseline=baseline_hours) residual_lower, residual_upper = get_residual_bounds(df_baseline) - hourly_stats = calculate_hourly_stats( - df, data_column="y", pred_column="y_hat", date_column="ds" + hourly_stats = calculate_stats_by_timeframe( + df, timeframe="hour", data_column="y", pred_column="y_hat", date_column="ds" ) hourly_stats = add_hourly_thresholds(hourly_stats, residual_lower, residual_upper) @@ -99,10 +135,10 @@ def test_calculate_thresholds(example_forecast_df): hypo_threshold = hourly_stats["hypo_threshold"].to_numpy() assert np.allclose( - fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221]) + fever_threshold[:3], np.array([39.09025125, 39.22112064, 38.65119125]) ) assert np.allclose( - hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112]) + hypo_threshold[:3], np.array([38.56102016, 38.69188955, 38.12196017]) ) @@ -127,10 +163,10 @@ def test_detect_fever_hypothermia(example_forecast_df): hypo_threshold = hourly_stats["hypo_threshold"].to_numpy() assert np.allclose( - fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221]) + fever_threshold[:3], np.array([39.09025125, 39.22112064, 38.65119125]) ) assert np.allclose( - hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112]) + hypo_threshold[:3], np.array([38.56102016, 38.69188955, 38.12196017]) ) @@ -156,8 +192,9 @@ def test_get_all_stats(example_forecast_df): assert np.allclose(results["residual"]["max_residual"], 2.70556) assert np.allclose(results["residual"]["residual_upper_bound"], 0.264615542) assert np.allclose(results["duration"]["total_duration_hours"], 693.210555) - assert results["duration"]["fever_hours"] == 261 - assert results["duration"]["hypothermia_hours"] == 157 + assert results["duration"]["fever_hours"] == 266 + assert results["duration"]["hypothermia_hours"] == 156 + def test_get_all_stats_m3924(m3924_forecast_df): df = m3924_forecast_df @@ -181,5 +218,5 @@ def test_get_all_stats_m3924(m3924_forecast_df): assert np.allclose(results["residual"]["max_residual"], 3.21165) assert np.allclose(results["residual"]["residual_upper_bound"], 0.2931) assert np.allclose(results["duration"]["total_duration_hours"], 248.75) - assert results["duration"]["fever_hours"] == 154 - assert results["duration"]["hypothermia_hours"] == 9 + assert results["duration"]["fever_hours"] == 153 + assert results["duration"]["hypothermia_hours"] == 8 diff --git a/vaxstats/analysis/forecast.py b/vaxstats/analysis/forecast.py index 8dfa57d..64c07fb 100644 --- a/vaxstats/analysis/forecast.py +++ b/vaxstats/analysis/forecast.py @@ -7,9 +7,9 @@ from ..io import load_file from ..utils import get_baseline_df, str_to_datetime -from .hourly import add_hourly_thresholds, calculate_hourly_stats from .residual import add_residuals_col, get_residual_bounds, get_residual_sum_square from .stats import get_column_max, get_column_mean, get_column_min, get_column_std +from .timeframe import add_hourly_thresholds, calculate_stats_by_timeframe def detect_fever_hypothermia( @@ -43,7 +43,9 @@ def detect_fever_hypothermia( logger.info("Detecting fever and hypothermia thresholds") df_baseline = get_baseline_df(df, date_column, date_fmt, baseline) residual_bounds = get_residual_bounds(df_baseline, residual_column) - hourly_stats = calculate_hourly_stats(df, data_column, pred_column, date_column) + hourly_stats = calculate_stats_by_timeframe( + df, "hour", data_column, pred_column, date_column + ) hourly_stats = add_hourly_thresholds(hourly_stats, *residual_bounds) return hourly_stats, residual_bounds @@ -91,7 +93,9 @@ def run_analysis( "std_dev_temp": float(get_column_std(df_baseline, data_column)), "max_temp": float(get_column_max(df_baseline, data_column)), "min_temp": float(get_column_min(df_baseline, data_column)), - "residual_sum_squares": float(get_residual_sum_square(df_baseline, residual_column)), + "residual_sum_squares": float( + get_residual_sum_square(df_baseline, residual_column) + ), } # Compute residual statistics diff --git a/vaxstats/analysis/hourly.py b/vaxstats/analysis/hourly.py deleted file mode 100644 index c590227..0000000 --- a/vaxstats/analysis/hourly.py +++ /dev/null @@ -1,56 +0,0 @@ -import polars as pl - - -def calculate_hourly_stats( - df: pl.DataFrame, - data_column: str = "y", - pred_column: str = "y_hat", - date_column: str = "ds", -) -> pl.DataFrame: - """ - Calculate hourly median temperatures. - - Args: - df: The input DataFrame. - data_column: Name of column containing observed data. - pred_column: Name of column containing predicted data. - date_column: The name of the column with timestamps. - - Returns: - A DataFrame with hourly median temperatures. - """ - return ( - df.with_columns(pl.col(date_column).dt.truncate("1h").alias("hour")) - .group_by("hour") - .agg( - pl.col(data_column).median().alias("y_median"), - pl.col(pred_column).median().alias("y_hat_median"), - pl.col(date_column).min().alias("start_time"), - pl.col(date_column).max().alias("end_time"), - pl.col(data_column).count().alias("data_points"), - ) - .sort("hour") - ) - - -def add_hourly_thresholds( - hourly_stats: pl.DataFrame, residual_lower: float, residual_upper: float -) -> pl.DataFrame: - """ - Calculate and add fever and hypothermia thresholds. - - Args: - hourly_stats: The DataFrame with hourly median temperatures. - residual_lower: The lower residual bound. - residual_upper: The upper residual bound. - - Returns: - A DataFrame with added columns for residual bounds and fever/hypothermia - thresholds. - """ - return hourly_stats.with_columns( - [ - (pl.col("y_hat_median") + residual_lower).alias("hypo_threshold"), - (pl.col("y_hat_median") + residual_upper).alias("fever_threshold"), - ] - ) diff --git a/vaxstats/analysis/timeframe.py b/vaxstats/analysis/timeframe.py new file mode 100644 index 0000000..3a88ecd --- /dev/null +++ b/vaxstats/analysis/timeframe.py @@ -0,0 +1,83 @@ +from typing import Literal + +import polars as pl + + +def calculate_stats_by_timeframe( + df: pl.DataFrame, + timeframe: Literal["hour", "day"], + data_column: str = "y", + pred_column: str = "y_hat", + date_column: str = "ds", + start_from_first: bool = True, +) -> pl.DataFrame: + """ + Calculate statistics by specified timeframe, with an option to start from the first timestamp. + + Args: + df: The input DataFrame. + timeframe: Group statistics by timeframe - options are "hour" and "day". + data_column: Name of column containing observed data. + pred_column: Name of column containing predicted data. + date_column: The name of the column with timestamps. + start_from_first: If `True`, groups by the first timestamp's interval; + otherwise, uses calendar-based intervals. + + Returns: + A DataFrame with aggregated statistics by specified timeframe. + """ + if start_from_first: + # Calculate elapsed time since the first timestamp and divide by duration to create intervals + stats_by_timeframe = df.group_by_dynamic( + index_column=date_column, + every=f"1{timeframe.lower()[0]}", + closed="both", + start_by="datapoint", + ).agg( + pl.col(data_column).median().alias("y_median"), + pl.col(pred_column).median().alias("y_hat_median"), + pl.col(date_column).min().alias("start_time"), + pl.col(date_column).max().alias("end_time"), + pl.col(data_column).count().alias("data_points"), + ) + else: + # Use calendar-based truncation with `truncate` + stats_by_timeframe = ( + df.with_columns( + pl.col(date_column).dt.truncate(f"1{timeframe[0]}").alias(timeframe) + ) + .group_by(timeframe) + .agg( + pl.col(data_column).median().alias("y_median"), + pl.col(pred_column).median().alias("y_hat_median"), + pl.col(date_column).min().alias("start_time"), + pl.col(date_column).max().alias("end_time"), + pl.col(data_column).count().alias("data_points"), + ) + .sort(timeframe) + ) + + return stats_by_timeframe + + +def add_hourly_thresholds( + hourly_stats: pl.DataFrame, residual_lower: float, residual_upper: float +) -> pl.DataFrame: + """ + Calculate and add fever and hypothermia thresholds. + + Args: + hourly_stats: The DataFrame with hourly median temperatures. + residual_lower: The lower residual bound. + residual_upper: The upper residual bound. + + Returns: + A DataFrame with added columns for residual bounds and fever/hypothermia + thresholds. + """ + return hourly_stats.with_columns( + [ + (pl.col("y_hat_median") + residual_lower).alias("hypo_threshold"), + (pl.col("y_hat_median") + residual_upper).alias("fever_threshold"), + ] + ) diff --git a/vaxstats/utils.py b/vaxstats/utils.py index acfa648..02609e5 100644 --- a/vaxstats/utils.py +++ b/vaxstats/utils.py @@ -116,7 +116,9 @@ def str_to_datetime( """ Converts DataFrame datetime column strings to datetimes. """ - df = df.with_columns(pl.col(date_column).str.strptime(pl.Datetime, date_fmt, strict=False)) + df = df.with_columns( + pl.col(date_column).str.strptime(pl.Datetime, date_fmt, strict=False) + ) if df[date_column][0] is None: logger.error(f"Date is Null, please check your `date_fmt` of {date_fmt}") raise RuntimeError