Skip to content

Commit

Permalink
Fixed date pair generator in train_pipeline.py
Browse files Browse the repository at this point in the history
  • Loading branch information
qinip committed Sep 24, 2024
1 parent 3cf7e01 commit c5172e6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
INITIAL_TRAIN_START = "2018-01-01"
INITIAL_TRAIN_YEARS = 4 # year(s)
FINAL_TRAIN_START = "2019-11-01"
PREDICTION_PERIOD = 3# months
PREDICTION_PERIOD = 3 # months
STEP_MONTHS = 1 # month(s)
MAX_PRED_END = "2024-02-01"

Expand Down
8 changes: 4 additions & 4 deletions src/train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dateutil.relativedelta import relativedelta
from colorlog import ColoredFormatter
from config import PROJ_FOLDER
from model_config import INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, FINAL_TRAIN_START, STEP_MONTHS, MAX_PRED_END, scheme_name
from model_config import INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, FINAL_TRAIN_START, PREDICTION_PERIOD, STEP_MONTHS, MAX_PRED_END, scheme_name

def setup_logger():
formatter = ColoredFormatter(
Expand All @@ -33,7 +33,7 @@ def setup_logger():
setup_logger()


def generate_date_pairs(start_str, initial_train_years, step_months, final_train_start, max_pred_end=MAX_PRED_END):
def generate_date_pairs(start_str, initial_train_years, prediction_period, step_months, final_train_start, max_pred_end=MAX_PRED_END):
"""Generate train and prediction date pairs based on the input parameters."""
# Convert string dates to datetime objects
current_train_start = datetime.strptime(start_str, "%Y-%m-%d")
Expand All @@ -46,7 +46,7 @@ def generate_date_pairs(start_str, initial_train_years, step_months, final_train
# Set train_end to initial_train_years years after train_start, minus one day
train_end = (current_train_start + relativedelta(years=initial_train_years) - timedelta(days=1)).strftime('%Y-%m-%d')
pred_start = (current_train_start + relativedelta(years=initial_train_years)).strftime('%Y-%m-%d')
pred_end_date = (current_train_start + relativedelta(years=initial_train_years, months=step_months))
pred_end_date = (current_train_start + relativedelta(years=initial_train_years, months=prediction_period))

# Ensure pred_end does not exceed the maximum allowed date
if pred_end_date > max_pred_end:
Expand All @@ -69,7 +69,7 @@ def main():
writer = csv.writer(file)
writer.writerow(['Train Start', 'Train End', 'Pred Start', 'Pred End', 'Test F1', 'Pred F1'])

for train_start, train_end, pred_start, pred_end in generate_date_pairs(INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, STEP_MONTHS, FINAL_TRAIN_START):
for train_start, train_end, pred_start, pred_end in generate_date_pairs(INITIAL_TRAIN_START, INITIAL_TRAIN_YEARS, PREDICTION_PERIOD, STEP_MONTHS, FINAL_TRAIN_START):
print(f"Training dataset: {train_start} to {train_end}, Prediction dataset: {pred_start} to {pred_end}")

try:
Expand Down

0 comments on commit c5172e6

Please sign in to comment.