Skip to content

Commit

Permalink
feat: modularize stats by timeframe
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexmmaldonado committed Nov 9, 2024
1 parent 82d366b commit 6a81a49
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 73 deletions.
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def path_m3924_excel():
def path_example_prepped_csv():
return os.path.join(TEST_DIR, "files/example_prepped.csv")


@pytest.fixture
def path_m3924_prepped_csv():
return os.path.join(TEST_DIR, "files/test-m3924.csv")
Expand All @@ -37,10 +38,12 @@ def path_m3924_prepped_csv():
def path_example_forecast_csv():
return os.path.join(TEST_DIR, "files/example_forecast.csv")


@pytest.fixture
def path_forecast_csv_m3924():
return os.path.join(TEST_DIR, "files/test-m3924-forecast.csv")


@pytest.fixture
def path_example_img():
return os.path.join(TEST_DIR, "tmp/path_example_img.png")
Expand All @@ -51,11 +54,13 @@ def example_forecast_df(path_example_forecast_csv):
df = load_file(path_example_forecast_csv, file_type="csv")
return df


@pytest.fixture
def m3924_forecast_df(path_forecast_csv_m3924):
df = load_file(path_forecast_csv_m3924, file_type="csv")
return df


@pytest.fixture
def example_forecast_df_baseline(path_example_forecast_csv):
baseline_days = 7.0
Expand Down
1 change: 1 addition & 0 deletions tests/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_arima_cli(capsys, monkeypatch):
print(captured)
assert os.path.exists(output_path), "Output file was not created"


def test_arima_m3924(path_m3924_prepped_csv):
model_kwargs = {
"order": (0, 0, 10),
Expand Down
63 changes: 50 additions & 13 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import polars as pl

from vaxstats.analysis.forecast import detect_fever_hypothermia, run_analysis
from vaxstats.analysis.hourly import add_hourly_thresholds, calculate_hourly_stats
from vaxstats.analysis.residual import add_residuals_col, get_residual_bounds
from vaxstats.analysis.stats import get_column_stat, get_column_stats
from vaxstats.analysis.timeframe import (
add_hourly_thresholds,
calculate_stats_by_timeframe,
)
from vaxstats.utils import get_baseline_df, str_to_datetime


Expand Down Expand Up @@ -43,13 +46,15 @@ def test_residual_bounds(example_forecast_df_baseline):
)


def test_calculate_hourly_stats(example_forecast_df_baseline):
def test_calculate_stats_by_hour(example_forecast_df_baseline):
df = example_forecast_df_baseline
hourly_stats = calculate_hourly_stats(
hourly_stats = calculate_stats_by_timeframe(
df,
timeframe="hour",
data_column="y",
pred_column="y_hat",
date_column="ds",
start_from_first=False,
)

assert "y_median" in hourly_stats.columns
Expand All @@ -75,6 +80,37 @@ def test_calculate_hourly_stats(example_forecast_df_baseline):
assert np.allclose(temps_hourly_median[-1], np.array([38.8151678]))


def test_calculate_stats_by_day(example_forecast_df_baseline):
df = example_forecast_df_baseline
stats = calculate_stats_by_timeframe(
df,
timeframe="day",
data_column="y",
pred_column="y_hat",
date_column="ds",
start_from_first=True,
)

assert "y_median" in stats.columns
assert "y_hat_median" in stats.columns
assert "start_time" in stats.columns
assert "end_time" in stats.columns
assert "data_points" in stats.columns

# Check if start_time is always less than or equal to end_time
assert (stats["start_time"] <= stats["end_time"]).all()

# Check if data_points is always positive
assert (stats["data_points"] > 0).all()

temps_median = stats["y_median"].to_numpy()
assert np.allclose(
temps_median[:4],
np.array([37.635, 37.707, 37.62, 37.501]),
)
assert np.allclose(temps_median[-1], np.array([37.2575]))


def test_calculate_thresholds(example_forecast_df):
df = example_forecast_df
df = add_residuals_col(df)
Expand All @@ -86,8 +122,8 @@ def test_calculate_thresholds(example_forecast_df):
df_baseline = get_baseline_df(df, baseline=baseline_hours)

residual_lower, residual_upper = get_residual_bounds(df_baseline)
hourly_stats = calculate_hourly_stats(
df, data_column="y", pred_column="y_hat", date_column="ds"
hourly_stats = calculate_stats_by_timeframe(
df, timeframe="hour", data_column="y", pred_column="y_hat", date_column="ds"
)
hourly_stats = add_hourly_thresholds(hourly_stats, residual_lower, residual_upper)

Expand All @@ -99,10 +135,10 @@ def test_calculate_thresholds(example_forecast_df):
hypo_threshold = hourly_stats["hypo_threshold"].to_numpy()

assert np.allclose(
fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221])
fever_threshold[:3], np.array([39.09025125, 39.22112064, 38.65119125])
)
assert np.allclose(
hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112])
hypo_threshold[:3], np.array([38.56102016, 38.69188955, 38.12196017])
)


Expand All @@ -127,10 +163,10 @@ def test_detect_fever_hypothermia(example_forecast_df):
hypo_threshold = hourly_stats["hypo_threshold"].to_numpy()

assert np.allclose(
fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221])
fever_threshold[:3], np.array([39.09025125, 39.22112064, 38.65119125])
)
assert np.allclose(
hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112])
hypo_threshold[:3], np.array([38.56102016, 38.69188955, 38.12196017])
)


Expand All @@ -156,8 +192,9 @@ def test_get_all_stats(example_forecast_df):
assert np.allclose(results["residual"]["max_residual"], 2.70556)
assert np.allclose(results["residual"]["residual_upper_bound"], 0.264615542)
assert np.allclose(results["duration"]["total_duration_hours"], 693.210555)
assert results["duration"]["fever_hours"] == 261
assert results["duration"]["hypothermia_hours"] == 157
assert results["duration"]["fever_hours"] == 266
assert results["duration"]["hypothermia_hours"] == 156


def test_get_all_stats_m3924(m3924_forecast_df):
df = m3924_forecast_df
Expand All @@ -181,5 +218,5 @@ def test_get_all_stats_m3924(m3924_forecast_df):
assert np.allclose(results["residual"]["max_residual"], 3.21165)
assert np.allclose(results["residual"]["residual_upper_bound"], 0.2931)
assert np.allclose(results["duration"]["total_duration_hours"], 248.75)
assert results["duration"]["fever_hours"] == 154
assert results["duration"]["hypothermia_hours"] == 9
assert results["duration"]["fever_hours"] == 153
assert results["duration"]["hypothermia_hours"] == 8
10 changes: 7 additions & 3 deletions vaxstats/analysis/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from ..io import load_file
from ..utils import get_baseline_df, str_to_datetime
from .hourly import add_hourly_thresholds, calculate_hourly_stats
from .residual import add_residuals_col, get_residual_bounds, get_residual_sum_square
from .stats import get_column_max, get_column_mean, get_column_min, get_column_std
from .timeframe import add_hourly_thresholds, calculate_stats_by_timeframe


def detect_fever_hypothermia(
Expand Down Expand Up @@ -43,7 +43,9 @@ def detect_fever_hypothermia(
logger.info("Detecting fever and hypothermia thresholds")
df_baseline = get_baseline_df(df, date_column, date_fmt, baseline)
residual_bounds = get_residual_bounds(df_baseline, residual_column)
hourly_stats = calculate_hourly_stats(df, data_column, pred_column, date_column)
hourly_stats = calculate_stats_by_timeframe(
df, "hour", data_column, pred_column, date_column
)
hourly_stats = add_hourly_thresholds(hourly_stats, *residual_bounds)
return hourly_stats, residual_bounds

Expand Down Expand Up @@ -91,7 +93,9 @@ def run_analysis(
"std_dev_temp": float(get_column_std(df_baseline, data_column)),
"max_temp": float(get_column_max(df_baseline, data_column)),
"min_temp": float(get_column_min(df_baseline, data_column)),
"residual_sum_squares": float(get_residual_sum_square(df_baseline, residual_column)),
"residual_sum_squares": float(
get_residual_sum_square(df_baseline, residual_column)
),
}

# Compute residual statistics
Expand Down
56 changes: 0 additions & 56 deletions vaxstats/analysis/hourly.py

This file was deleted.

83 changes: 83 additions & 0 deletions vaxstats/analysis/timeframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Literal

import polars as pl


def calculate_stats_by_timeframe(
df: pl.DataFrame,
timeframe: Literal["hour", "day"],
data_column: str = "y",
pred_column: str = "y_hat",
date_column: str = "ds",
start_from_first: bool = True,
) -> pl.DataFrame:
"""
Calculate statistics by specified timeframe, with an option to start from the first timestamp.
Args:
df: The input DataFrame.
timeframe: Group statistics by timeframe - options are "hour" and "day".
data_column: Name of column containing observed data.
pred_column: Name of column containing predicted data.
date_column: The name of the column with timestamps.
start_from_first: If `True`, groups by the first timestamp's interval;
otherwise, uses calendar-based intervals.
Returns:
A DataFrame with aggregated statistics by specified timeframe.
"""
if start_from_first:
# Calculate elapsed time since the first timestamp and divide by duration to create intervals
stats_by_timeframe = df.group_by_dynamic(
index_column=date_column,
every=f"1{timeframe.lower()[0]}",
closed="both",
start_by="datapoint",
).agg(
pl.col(data_column).median().alias("y_median"),
pl.col(pred_column).median().alias("y_hat_median"),
pl.col(date_column).min().alias("start_time"),
pl.col(date_column).max().alias("end_time"),
pl.col(data_column).count().alias("data_points"),
)
else:
# Use calendar-based truncation with `truncate`
stats_by_timeframe = (
df.with_columns(
pl.col(date_column).dt.truncate(f"1{timeframe[0]}").alias(timeframe)
)
.group_by(timeframe)
.agg(
pl.col(data_column).median().alias("y_median"),
pl.col(pred_column).median().alias("y_hat_median"),
pl.col(date_column).min().alias("start_time"),
pl.col(date_column).max().alias("end_time"),
pl.col(data_column).count().alias("data_points"),
)
.sort(timeframe)
)

return stats_by_timeframe


def add_hourly_thresholds(
hourly_stats: pl.DataFrame, residual_lower: float, residual_upper: float
) -> pl.DataFrame:
"""
Calculate and add fever and hypothermia thresholds.
Args:
hourly_stats: The DataFrame with hourly median temperatures.
residual_lower: The lower residual bound.
residual_upper: The upper residual bound.
Returns:
A DataFrame with added columns for residual bounds and fever/hypothermia
thresholds.
"""
return hourly_stats.with_columns(
[
(pl.col("y_hat_median") + residual_lower).alias("hypo_threshold"),
(pl.col("y_hat_median") + residual_upper).alias("fever_threshold"),
]
)
4 changes: 3 additions & 1 deletion vaxstats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def str_to_datetime(
"""
Converts DataFrame datetime column strings to datetimes.
"""
df = df.with_columns(pl.col(date_column).str.strptime(pl.Datetime, date_fmt, strict=False))
df = df.with_columns(
pl.col(date_column).str.strptime(pl.Datetime, date_fmt, strict=False)
)
if df[date_column][0] is None:
logger.error(f"Date is Null, please check your `date_fmt` of {date_fmt}")
raise RuntimeError
Expand Down

0 comments on commit 6a81a49

Please sign in to comment.