Skip to content

Commit

Permalink
[Minor] Improve Season glocal reg invalid parameter handling (#1601)
Browse files Browse the repository at this point in the history
* reduce warning

* assertionError for wrong values
  • Loading branch information
ourownstory authored Jun 26, 2024
1 parent 92a0198 commit e90bf5a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
6 changes: 2 additions & 4 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,10 @@ def __post_init__(self):
}
)

if self.seasonality_local_reg < 0:
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False
assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg)

Check failure on line 393 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator ">=" not supported for "None" (reportOptionalOperand)

if self.seasonality_local_reg is True:
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global but local regularization is set.
Expand Down
23 changes: 22 additions & 1 deletion tests/test_glocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib

import pandas as pd
import pytest

from neuralprophet import NeuralProphet

Expand Down Expand Up @@ -341,14 +342,16 @@ def test_glocal_seasonality_reg():
df2_0["ID"] = "df2"
df3_0 = df.iloc[256:384, :].copy(deep=True)
df3_0["ID"] = "df3"
for _ in [-30, 0, False, True]:
for coef_i in [0, 1.5, False, True]:

m = NeuralProphet(
n_forecasts=1,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
season_global_local="local",
yearly_seasonality_glocal_mode="global",
seasonality_local_reg=coef_i,
)

m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global")
Expand All @@ -359,6 +362,24 @@ def test_glocal_seasonality_reg():
metrics = m.test(test_df)
log.info(f"forecast = {forecast}, metrics = {metrics}")

with pytest.raises(AssertionError, match="Invalid seasonality_local_reg"):
m = NeuralProphet(
n_forecasts=1,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
season_global_local="local",
yearly_seasonality_glocal_mode="global",
seasonality_local_reg=-324,
)

m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global")
train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True)
m.fit(train_df)
future = m.make_future_dataframe(test_df, n_historic_predictions=True)
forecast = m.predict(future)
metrics = m.test(test_df)


def test_trend_local_reg_if_global():
# SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES
Expand Down

0 comments on commit e90bf5a

Please sign in to comment.