From 17a895be93fb0226f97d9c5cf4ba8b4cd3abd8cc Mon Sep 17 00:00:00 2001 From: t-minus Date: Fri, 10 May 2024 02:18:04 +0800 Subject: [PATCH] Fix missing required parameters --- nbs/common.base_multivariate.ipynb | 4 ++++ nbs/common.base_recurrent.ipynb | 4 ++++ nbs/common.base_windows.ipynb | 4 ++++ neuralforecast/common/_base_multivariate.py | 4 ++++ neuralforecast/common/_base_recurrent.py | 4 ++++ neuralforecast/common/_base_windows.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/nbs/common.base_multivariate.ipynb b/nbs/common.base_multivariate.ipynb index 959c047b2..c5e09ddab 100644 --- a/nbs/common.base_multivariate.ipynb +++ b/nbs/common.base_multivariate.ipynb @@ -107,6 +107,8 @@ " alias=None,\n", " optimizer=None,\n", " optimizer_kwargs=None,\n", + " lr_scheduler=None,\n", + " lr_scheduler_kwargs=None,\n", " **trainer_kwargs):\n", " super().__init__(\n", " random_seed=random_seed,\n", @@ -114,6 +116,8 @@ " valid_loss=valid_loss,\n", " optimizer=optimizer,\n", " optimizer_kwargs=optimizer_kwargs,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_kwargs=lr_scheduler_kwargs, \n", " futr_exog_list=futr_exog_list,\n", " hist_exog_list=hist_exog_list,\n", " stat_exog_list=stat_exog_list,\n", diff --git a/nbs/common.base_recurrent.ipynb b/nbs/common.base_recurrent.ipynb index 835242309..433451676 100644 --- a/nbs/common.base_recurrent.ipynb +++ b/nbs/common.base_recurrent.ipynb @@ -112,6 +112,8 @@ " alias=None,\n", " optimizer=None,\n", " optimizer_kwargs=None,\n", + " lr_scheduler=None,\n", + " lr_scheduler_kwargs=None,\n", " **trainer_kwargs):\n", " super().__init__(\n", " random_seed=random_seed,\n", @@ -119,6 +121,8 @@ " valid_loss=valid_loss,\n", " optimizer=optimizer,\n", " optimizer_kwargs=optimizer_kwargs,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_kwargs=lr_scheduler_kwargs,\n", " futr_exog_list=futr_exog_list,\n", " hist_exog_list=hist_exog_list,\n", " stat_exog_list=stat_exog_list,\n", diff --git a/nbs/common.base_windows.ipynb b/nbs/common.base_windows.ipynb index f4b1da83b..00d9100f2 100644 --- a/nbs/common.base_windows.ipynb +++ b/nbs/common.base_windows.ipynb @@ -117,6 +117,8 @@ " alias=None,\n", " optimizer=None,\n", " optimizer_kwargs=None,\n", + " lr_scheduler=None,\n", + " lr_scheduler_kwargs=None,\n", " **trainer_kwargs):\n", " super().__init__(\n", " random_seed=random_seed,\n", @@ -124,6 +126,8 @@ " valid_loss=valid_loss,\n", " optimizer=optimizer,\n", " optimizer_kwargs=optimizer_kwargs,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_kwargs=lr_scheduler_kwargs,\n", " futr_exog_list=futr_exog_list,\n", " hist_exog_list=hist_exog_list,\n", " stat_exog_list=stat_exog_list,\n", diff --git a/neuralforecast/common/_base_multivariate.py b/neuralforecast/common/_base_multivariate.py index 70802c7af..cdc0a1bc2 100644 --- a/neuralforecast/common/_base_multivariate.py +++ b/neuralforecast/common/_base_multivariate.py @@ -52,6 +52,8 @@ def __init__( alias=None, optimizer=None, optimizer_kwargs=None, + lr_scheduler=None, + lr_scheduler_kwargs=None, **trainer_kwargs, ): super().__init__( @@ -60,6 +62,8 @@ def __init__( valid_loss=valid_loss, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, futr_exog_list=futr_exog_list, hist_exog_list=hist_exog_list, stat_exog_list=stat_exog_list, diff --git a/neuralforecast/common/_base_recurrent.py b/neuralforecast/common/_base_recurrent.py index 334f22e8a..6f43bfd86 100644 --- a/neuralforecast/common/_base_recurrent.py +++ b/neuralforecast/common/_base_recurrent.py @@ -51,6 +51,8 @@ def __init__( alias=None, optimizer=None, optimizer_kwargs=None, + lr_scheduler=None, + lr_scheduler_kwargs=None, **trainer_kwargs, ): super().__init__( @@ -59,6 +61,8 @@ def __init__( valid_loss=valid_loss, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, futr_exog_list=futr_exog_list, hist_exog_list=hist_exog_list, stat_exog_list=stat_exog_list, diff --git a/neuralforecast/common/_base_windows.py b/neuralforecast/common/_base_windows.py index ea2ba71bb..f9543bf79 100644 --- a/neuralforecast/common/_base_windows.py +++ b/neuralforecast/common/_base_windows.py @@ -55,6 +55,8 @@ def __init__( alias=None, optimizer=None, optimizer_kwargs=None, + lr_scheduler=None, + lr_scheduler_kwargs=None, **trainer_kwargs, ): super().__init__( @@ -63,6 +65,8 @@ def __init__( valid_loss=valid_loss, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, futr_exog_list=futr_exog_list, hist_exog_list=hist_exog_list, stat_exog_list=stat_exog_list,