Skip to content

Commit

Permalink
refactor: standardize str to datetime
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexmmaldonado committed Aug 22, 2024
1 parent a694d8e commit 3d03401
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 5 deletions.
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def path_example_img():
@pytest.fixture
def example_forecast_df(path_example_forecast_csv):
df = load_file(path_example_forecast_csv, file_type="csv")
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")
return df


Expand Down
2 changes: 2 additions & 0 deletions tests/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from vaxstats.cli import main
from vaxstats.forecast import run_forecasting
from vaxstats.io import load_file
from vaxstats.utils import str_to_datetime


def test_arima(path_example_prepped_csv):
Expand All @@ -15,6 +16,7 @@ def test_arima(path_example_prepped_csv):
"method": "CSS-ML", # CSS-ML, ML, CSS
}
df = load_file(path_example_prepped_csv, "csv")
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")
baseline_days = 7.0
baseline_hours = 24 * baseline_days
df = run_forecasting(
Expand Down
8 changes: 6 additions & 2 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vaxstats.analysis.hourly import add_hourly_thresholds, calculate_hourly_stats
from vaxstats.analysis.residual import add_residuals_col, get_residual_bounds
from vaxstats.analysis.stats import get_column_stat, get_column_stats
from vaxstats.utils import get_baseline_df
from vaxstats.utils import get_baseline_df, str_to_datetime


def test_get_column_stat(example_forecast_df):
Expand Down Expand Up @@ -44,8 +44,9 @@ def test_residual_bounds(example_forecast_df_baseline):


def test_calculate_hourly_stats(example_forecast_df_baseline):
df = example_forecast_df_baseline
hourly_stats = calculate_hourly_stats(
example_forecast_df_baseline,
df,
data_column="y",
pred_column="y_hat",
date_column="ds",
Expand Down Expand Up @@ -77,6 +78,7 @@ def test_calculate_hourly_stats(example_forecast_df_baseline):
def test_calculate_thresholds(example_forecast_df, baseline_hours):
df = example_forecast_df
df = add_residuals_col(df)
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")
df_baseline = get_baseline_df(df, baseline=baseline_hours)

residual_lower, residual_upper = get_residual_bounds(df_baseline)
Expand All @@ -103,6 +105,7 @@ def test_calculate_thresholds(example_forecast_df, baseline_hours):
def test_detect_fever_hypothermia(example_forecast_df, baseline_hours):
df = example_forecast_df
df = add_residuals_col(df)
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")
hourly_stats, residual_bounds = detect_fever_hypothermia(
df, baseline=baseline_hours
)
Expand All @@ -126,6 +129,7 @@ def test_detect_fever_hypothermia(example_forecast_df, baseline_hours):
def test_get_all_stats(example_forecast_df, baseline_hours):
df = example_forecast_df
df = add_residuals_col(df)
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")

results = run_analysis(
df,
Expand Down
2 changes: 1 addition & 1 deletion vaxstats/analysis/hourly.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def calculate_hourly_stats(
"""
return (
df.with_columns(pl.col(date_column).dt.truncate("1h").alias("hour"))
.groupby("hour")
.group_by("hour")
.agg(
pl.col(data_column).median().alias("y_median"),
pl.col(pred_column).median().alias("y_hat_median"),
Expand Down
3 changes: 2 additions & 1 deletion vaxstats/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .analysis.residual import add_residuals_col
from .io import load_file
from .log import run_with_progress_logging
from .utils import split_df
from .utils import split_df, str_to_datetime


def run_forecasting(
Expand Down Expand Up @@ -81,6 +81,7 @@ def run_forecasting(
def cli_forecast(args, sf_model_args=(), sf_model_kwargs={}):
# Load the input file
df = load_file(args.file_path)
df = str_to_datetime(df, date_column="ds", date_fmt="%Y-%m-%d %H:%M:%S")

# Import the forecasting model class dynamically
module_name, class_name = args.sf_model.rsplit(".", 1)
Expand Down

0 comments on commit 3d03401

Please sign in to comment.