Skip to content

Commit

Permalink
tests: more relevant values
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexmmaldonado committed Aug 21, 2024
1 parent e8d7185 commit ebf41b1
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
)
from vaxstats.utils import split_df

baseline_days = 7.0
baseline_hours = 24 * baseline_days


def test_get_column_stat(example_forecast_df):
result = get_column_stat(example_forecast_df, "y", pl.mean)
Expand All @@ -38,8 +41,6 @@ def test_baseline_stats(example_forecast_df):
def test_residual_stats(example_forecast_df):
df = example_forecast_df

baseline_days = 7.0
baseline_hours = 24 * baseline_days
df = split_df(df, hours=baseline_hours)[0]

df = add_residuals_col(df)
Expand All @@ -52,8 +53,6 @@ def test_residual_stats(example_forecast_df):
def test_residual_bounds(example_forecast_df):
df = example_forecast_df
df = add_residuals_col(df)
baseline_days = 7.0
baseline_hours = 24 * baseline_days
residual_bounds = get_residual_bounds(df, baseline=baseline_hours)
assert np.allclose(
np.array(residual_bounds), np.array((-0.26462, 0.26462)), atol=0.0001
Expand Down Expand Up @@ -89,7 +88,9 @@ def test_calculate_hourly_stats(example_forecast_df):
def test_calculate_thresholds(example_forecast_df):
example_forecast_df = add_residuals_col(example_forecast_df)
hourly_stats = calculate_hourly_stats(example_forecast_df, data_column="y_hat")
residual_lower, residual_upper = get_residual_bounds(example_forecast_df)
residual_lower, residual_upper = get_residual_bounds(
example_forecast_df, baseline=baseline_hours
)
result = calculate_thresholds(hourly_stats, residual_upper, residual_lower)

assert result.shape[1] == 7
Expand All @@ -99,25 +100,44 @@ def test_calculate_thresholds(example_forecast_df):
fever_threshold = result["fever_threshold"].to_numpy()
hypo_threshold = result["hypo_threshold"].to_numpy()

assert np.allclose(fever_threshold[:3], np.array([41.4767, 41.2973, 41.2594]))
assert np.allclose(hypo_threshold[:3], np.array([36.4484, 36.2690, 36.2310]))
assert np.allclose(
fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221])
)
assert np.allclose(
hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112])
)


def test_detect_fever_hypothermia(example_forecast_df):
example_forecast_df = add_residuals_col(example_forecast_df)
result = detect_fever_hypothermia(example_forecast_df)
print(result)
result = detect_fever_hypothermia(example_forecast_df, baseline=baseline_hours)

assert result.shape[1] == 7 # Should have 5 columns
assert "hourly_median_temp" in result.columns
assert "fever_threshold" in result.columns
assert "hypo_threshold" in result.columns

fever_threshold = result["fever_threshold"].to_numpy()
hypo_threshold = result["hypo_threshold"].to_numpy()

assert np.allclose(
fever_threshold[:3], np.array([39.22711414, 39.0477938, 39.00983221])
)
assert np.allclose(
hypo_threshold[:3], np.array([38.69788305, 38.51856271, 38.48060112])
)


def test_get_all_stats(example_forecast_df):
df = add_residuals_col(example_forecast_df)
hourly_stats = detect_fever_hypothermia(df, pred_column="y_hat")
residual_bounds = get_residual_bounds(df)
print(hourly_stats)
baseline_days = 7.0
baseline_hours = 24 * baseline_days

hourly_stats = detect_fever_hypothermia(
df, pred_column="y_hat", baseline=baseline_hours
)

residual_bounds = get_residual_bounds(df, baseline=baseline_hours)

stats_dict = compute_stats_dict(
df,
Expand All @@ -126,3 +146,4 @@ def test_get_all_stats(example_forecast_df):
hourly_stats=hourly_stats,
residual_bounds=residual_bounds,
)
print(stats_dict)

0 comments on commit ebf41b1

Please sign in to comment.