Skip to content

Commit

Permalink
return Naive model for constant series in AutoCES (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Apr 16, 2024
1 parent c3bb82b commit 82a986c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
16 changes: 13 additions & 3 deletions nbs/src/core/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@
"\n",
"from statsforecast.arima import (\n",
" Arima,\n",
" auto_arima_f, forecast_arima, \n",
" fitted_arima, forward_arima\n",
" auto_arima_f,\n",
" fitted_arima,\n",
" forecast_arima, \n",
" forward_arima,\n",
" is_constant,\n",
")\n",
"from statsforecast.ces import (\n",
" auto_ces, forecast_ces,\n",
Expand Down Expand Up @@ -1680,8 +1683,12 @@
" self : \n",
" Complex Exponential Smoothing fitted model.\n",
" \"\"\"\n",
" if is_constant(y):\n",
" model = Naive(alias=self.alias, prediction_intervals=self.prediction_intervals)\n",
" model.fit(y=y, X=X)\n",
" return model\n",
" self.model_ = auto_ces(y, m=self.season_length, model=self.model)\n",
" self.model_['actual_residuals'] = y - self.model_['fitted']\n",
" self.model_['actual_residuals'] = y - self.model_['fitted'] \n",
" self._store_cs(y=y, X=X)\n",
" return self\n",
" \n",
Expand Down Expand Up @@ -1777,6 +1784,9 @@
" forecasts : dict \n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" if is_constant(y):\n",
" model = Naive(alias=self.alias, prediction_intervals=self.prediction_intervals)\n",
" return model.forecast(y=y, h=h, X=X, X_future=X_future, level=level, fitted=fitted)\n",
" mod = auto_ces(y, m=self.season_length, model=self.model)\n",
" fcst = forecast_ces(mod, h, level=level)\n",
" keys = ['mean']\n",
Expand Down
34 changes: 33 additions & 1 deletion statsforecast/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from statsforecast.arima import (
Arima,
auto_arima_f,
forecast_arima,
fitted_arima,
forecast_arima,
forward_arima,
is_constant,
)
from .ces import auto_ces, forecast_ces, forward_ces
from statsforecast.ets import (
Expand Down Expand Up @@ -962,6 +963,12 @@ def fit(
self :
Complex Exponential Smoothing fitted model.
"""
if is_constant(y):
model = Naive(
alias=self.alias, prediction_intervals=self.prediction_intervals
)
model.fit(y=y, X=X)
return model
self.model_ = auto_ces(y, m=self.season_length, model=self.model)
self.model_["actual_residuals"] = y - self.model_["fitted"]
self._store_cs(y=y, X=X)
Expand Down Expand Up @@ -1056,6 +1063,13 @@ def forecast(
forecasts : dict
Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.
"""
if is_constant(y):
model = Naive(
alias=self.alias, prediction_intervals=self.prediction_intervals
)
return model.forecast(
y=y, h=h, X=X, X_future=X_future, level=level, fitted=fitted
)
mod = auto_ces(y, m=self.season_length, model=self.model)
fcst = forecast_ces(mod, h, level=level)
keys = ["mean"]
Expand Down Expand Up @@ -2375,6 +2389,7 @@ def _seasonal_ses_optimized(

# %% ../nbs/src/core/models.ipynb 161
class SeasonalExponentialSmoothingOptimized(_TS):

def __init__(
self,
season_length: int,
Expand Down Expand Up @@ -2579,6 +2594,7 @@ def __init__(
alias: str = "Holt",
prediction_intervals: Optional[ConformalIntervals] = None,
):

self.season_length = season_length
self.error_type = error_type
self.alias = alias
Expand Down Expand Up @@ -2648,6 +2664,7 @@ def _historic_average(

# %% ../nbs/src/core/models.ipynb 204
class HistoricAverage(_TS):

def __init__(
self,
alias: str = "HistoricAverage",
Expand Down Expand Up @@ -2826,6 +2843,7 @@ def forecast(

# %% ../nbs/src/core/models.ipynb 217
class Naive(_TS):

def __init__(
self,
alias: str = "Naive",
Expand Down Expand Up @@ -3054,6 +3072,7 @@ def _random_walk_with_drift(

# %% ../nbs/src/core/models.ipynb 234
class RandomWalkWithDrift(_TS):

def __init__(
self,
alias: str = "RWD",
Expand Down Expand Up @@ -3232,6 +3251,7 @@ def forecast(

# %% ../nbs/src/core/models.ipynb 249
class SeasonalNaive(_TS):

def __init__(
self,
season_length: int,
Expand Down Expand Up @@ -3435,6 +3455,7 @@ def _window_average(

# %% ../nbs/src/core/models.ipynb 265
class WindowAverage(_TS):

def __init__(
self,
window_size: int,
Expand Down Expand Up @@ -3610,6 +3631,7 @@ def _seasonal_window_average(

# %% ../nbs/src/core/models.ipynb 277
class SeasonalWindowAverage(_TS):

def __init__(
self,
season_length: int,
Expand Down Expand Up @@ -3865,6 +3887,7 @@ def _adida(

# %% ../nbs/src/core/models.ipynb 290
class ADIDA(_TS):

def __init__(
self,
alias: str = "ADIDA",
Expand Down Expand Up @@ -4062,6 +4085,7 @@ def _croston_classic(

# %% ../nbs/src/core/models.ipynb 303
class CrostonClassic(_TS):

def __init__(
self,
alias: str = "CrostonClassic",
Expand Down Expand Up @@ -4268,6 +4292,7 @@ def _croston_optimized(

# %% ../nbs/src/core/models.ipynb 315
class CrostonOptimized(_TS):

def __init__(
self,
alias: str = "CrostonOptimized",
Expand Down Expand Up @@ -4441,6 +4466,7 @@ def _croston_sba(

# %% ../nbs/src/core/models.ipynb 327
class CrostonSBA(_TS):

def __init__(
self,
alias: str = "CrostonSBA",
Expand Down Expand Up @@ -4641,6 +4667,7 @@ def _imapa(

# %% ../nbs/src/core/models.ipynb 339
class IMAPA(_TS):

def __init__(
self,
alias: str = "IMAPA",
Expand Down Expand Up @@ -4831,6 +4858,7 @@ def _tsb(

# %% ../nbs/src/core/models.ipynb 351
class TSB(_TS):

def __init__(
self,
alpha_d: float,
Expand Down Expand Up @@ -5058,6 +5086,7 @@ def __init__(
alias: str = "MSTL",
prediction_intervals: Optional[ConformalIntervals] = None,
):

# check ETS model doesnt have seasonality
if repr(trend_forecaster) == "AutoETS":
if trend_forecaster.model[2] != "N":
Expand Down Expand Up @@ -6177,6 +6206,7 @@ def forward(

# %% ../nbs/src/core/models.ipynb 490
class ConstantModel(_TS):

def __init__(self, constant: float, alias: str = "ConstantModel"):
"""Constant Model.
Expand Down Expand Up @@ -6362,6 +6392,7 @@ def forward(

# %% ../nbs/src/core/models.ipynb 504
class ZeroModel(ConstantModel):

def __init__(self, alias: str = "ZeroModel"):
"""Returns Zero forecasts.
Expand All @@ -6376,6 +6407,7 @@ def __init__(self, alias: str = "ZeroModel"):

# %% ../nbs/src/core/models.ipynb 518
class NaNModel(ConstantModel):

def __init__(self, alias: str = "NaNModel"):
"""NaN Model.
Expand Down

0 comments on commit 82a986c

Please sign in to comment.