From c5172e6401bd86ef8a964d374f57aadae16f80aa Mon Sep 17 00:00:00 2001 From: Eddy Zhiqiang Ji Date: Tue, 24 Sep 2024 18:26:18 -0400 Subject: [PATCH] Fixed date pair generator in train_pipeline.py --- src/model_config.py | 2 +- src/train_pipeline.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/model_config.py b/src/model_config.py index 6d6171f..33550cb 100644 --- a/src/model_config.py +++ b/src/model_config.py @@ -11,7 +11,7 @@ INITIAL_TRAIN_START = "2018-01-01" INITIAL_TRAIN_YEARS = 4 #year(s) FINAL_TRAIN_START = "2019-11-01" -PREDICTION_PERIOD = 3# months +PREDICTION_PERIOD = 3 # months STEP_MONTHS = 1 # month(s) MAX_PRED_END = "2024-02-01" diff --git a/src/train_pipeline.py b/src/train_pipeline.py index 0a4909b..5d3a920 100644 --- a/src/train_pipeline.py +++ b/src/train_pipeline.py @@ -8,7 +8,7 @@ from dateutil.relativedelta import relativedelta from colorlog import ColoredFormatter from config import PROJ_FOLDER -from model_config import INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, FINAL_TRAIN_START, STEP_MONTHS, MAX_PRED_END, scheme_name +from model_config import INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, FINAL_TRAIN_START, PREDICTION_PERIOD, STEP_MONTHS, MAX_PRED_END, scheme_name def setup_logger(): formatter = ColoredFormatter( @@ -33,7 +33,7 @@ def setup_logger(): setup_logger() -def generate_date_pairs(start_str, initial_train_years, step_months, final_train_start, max_pred_end=MAX_PRED_END): +def generate_date_pairs(start_str, initial_train_years, prediction_period, step_months, final_train_start, max_pred_end=MAX_PRED_END): """Generate train and prediction date pairs based on the input parameters.""" # Convert string dates to datetime objects current_train_start = datetime.strptime(start_str, "%Y-%m-%d") @@ -46,7 +46,7 @@ def generate_date_pairs(start_str, initial_train_years, step_months, final_train # Set train_end to one year after train_start minus one day train_end = (current_train_start + relativedelta(years=initial_train_years) - timedelta(days=1)).strftime('%Y-%m-%d') pred_start = (current_train_start + relativedelta(years=initial_train_years)).strftime('%Y-%m-%d') - pred_end_date = (current_train_start + relativedelta(years=initial_train_years, months=step_months)) + pred_end_date = (current_train_start + relativedelta(years=initial_train_years, months=prediction_period)) # Ensure pred_end does not exceed the maximum allowed date if pred_end_date > max_pred_end: @@ -69,7 +69,7 @@ def main(): writer = csv.writer(file) writer.writerow(['Train Start', 'Train End', 'Pred Start', 'Pred End', 'Test F1', 'Pred F1']) - for train_start, train_end, pred_start, pred_end in generate_date_pairs(INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, STEP_MONTHS, FINAL_TRAIN_START): + for train_start, train_end, pred_start, pred_end in generate_date_pairs(INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, PREDICTION_PERIOD, STEP_MONTHS, FINAL_TRAIN_START): print(f"Training dataset: {train_start} to {train_end}, Prediction dataset: {pred_start} to {pred_end}") try: