[FEAT] Add example of modifying the default configure_optimizers() behavior (use of ReduceLROnPlateau scheduler) #1015

Open — JQGoh wants to merge 21 commits into base: main from feat/modify-config-optimizers

Changes from 1 commit (21 commits total)
bf2e11d
Review: Add warning if users pass only optimizer_kwargs but not optim…
JQGoh May 23, 2024
7d83799
option to modify configure_optimizers
JQGoh May 23, 2024
ce6848c
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Aug 2, 2024
3539eb3
Add example on ReduceLROnPlateau
JQGoh Aug 2, 2024
c0a7eb8
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Dec 11, 2024
2e3d7c6
Remove old interface and deprecate the arguments
JQGoh Dec 11, 2024
eb17a4d
Fix test
JQGoh Dec 11, 2024
f5a4f62
Correction to the path in contributing.md note
JQGoh Dec 18, 2024
829d80a
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Dec 21, 2024
3bab7c0
Review: Allow users to specifify the configure_optimizers() for
JQGoh Dec 22, 2024
c5f20a1
Add arguments and doc string
JQGoh Dec 22, 2024
a817907
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Jan 21, 2025
a5ca43e
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Jan 27, 2025
5fb6331
Revert "Add arguments and doc string"
JQGoh Jan 28, 2025
3ca378a
Revert "Review: Allow users to specifify the configure_optimizers() for"
JQGoh Jan 28, 2025
5c8e1c2
Revert "Fix test"
JQGoh Jan 28, 2025
cc317cf
Revert "Remove old interface and deprecate the arguments"
JQGoh Jan 28, 2025
0266daf
Revert "Add example on ReduceLROnPlateau"
JQGoh Jan 28, 2025
5765c20
Revert "option to modify configure_optimizers"
JQGoh Jan 28, 2025
0d07a7a
Omit unnecessary changes
JQGoh Jan 28, 2025
684184f
Add documentation on customizing configure_optimizers()
JQGoh Jan 28, 2025
Merge branch 'main' into feat/modify-config-optimizers
JQGoh committed Aug 2, 2024
commit ce6848c7f8003d728380694a51b5721da92e355f
38 changes: 28 additions & 10 deletions nbs/common.base_model.ipynb
@@ -89,7 +89,7 @@
" kaiming_normal = nn.init.kaiming_normal_\n",
" xavier_uniform = nn.init.xavier_uniform_\n",
" xavier_normal = nn.init.xavier_normal_\n",
" \n",
" \n",
" nn.init.kaiming_uniform_ = noop\n",
" nn.init.kaiming_normal_ = noop\n",
" nn.init.xavier_uniform_ = noop\n",
@@ -156,6 +156,12 @@
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
"\n",
" # lr scheduler\n",
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # customized by set_configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
@@ -408,17 +414,28 @@
" if self.optimizer_kwargs:\n",
" warnings.warn(\n",
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
" ) \n",
" )\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
" \n",
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
" if self.lr_scheduler:\n",
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
" if 'optimizer' in lr_scheduler_kwargs:\n",
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
" del lr_scheduler_kwargs['optimizer']\n",
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
" else:\n",
" if self.lr_scheduler_kwargs:\n",
" warnings.warn(\n",
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
" ) \n",
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" ),\n",
" \"frequency\": 1,\n",
" \"interval\": \"step\",\n",
" }\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
" \n",
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
"\n",
" def set_configure_optimizers(\n",
" self, \n",
" optimizer=None,\n",
@@ -444,6 +461,7 @@
" 'strict': strict,\n",
" 'name': name,\n",
" }\n",
"\n",
" if scheduler is not None and optimizer is not None:\n",
" if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")\n",
139 changes: 128 additions & 11 deletions nbs/core.ipynb
@@ -3007,7 +3007,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ddeec474",
"id": "c97858b5-e6a0-4353-a48f-5a5460eb2314",
"metadata": {},
"outputs": [],
"source": [
@@ -3018,7 +3018,7 @@
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24,\n",
" \"input_size\": 24, \n",
" \"max_steps\": 1,\n",
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
" }\n",
@@ -3035,7 +3035,113 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fef9925d-f80c-4851-ba35-4fd6e20162db",
"id": "24142322",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test customized lr_scheduler behavior such that the user defined lr_scheduler result should differ from default\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" nf.fit(AirPassengersPanel_train)\n",
" default_optimizer_predict = nf.predict()\n",
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" # using a customized lr_scheduler, default is StepLR\n",
" params.update({\n",
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR,\n",
" \"lr_scheduler_kwargs\": {\"factor\": 0.78}, \n",
" })\n",
" models2 = [nf_model(**params)]\n",
" nf2 = NeuralForecast(models=models2, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
" customized_optimizer_predict = nf2.predict()\n",
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
" assert mean2 != mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54c7b5e2",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if the user-defined lr_scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n",
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1d8bebb",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass in \"optimizer\" parameter, we expect warning and it ignores them\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1, \n",
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR, \n",
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
" }\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\" in str(w.message) for w in issued_warnings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06febece",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass in \"lr_scheduler_kwargs\" but not \"lr_scheduler\", we expect a warning\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1,\n",
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
" }\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\" in str(w.message) for w in issued_warnings)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b942face",
"metadata": {},
"outputs": [],
"source": [
@@ -3059,6 +3165,7 @@
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
"\n",
" )\n",
" nf2 = NeuralForecast(models=models, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
@@ -3077,16 +3184,26 @@
" customized_predict3 = nf3.predict()\n",
" mean3 = customized_predict3.loc[:, nf_model.__name__].mean()\n",
" assert mean3 != mean\n",
"\n",
" # test ReduceLROnPlateau\n",
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
" optimizer, mode='min', factor=0.5, patience=2,\n",
" )\n",
" \n",
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
" monitor=\"val_loss\",\n",
" )\n",
" nf3 = NeuralForecast(models=models, freq='M')\n",
" nf3.fit(AirPassengersPanel_train)\n",
" customized_predict3 = nf3.predict()\n",
" mean3 = customized_predict3.loc[:, nf_model.__name__].mean()\n",
" assert mean3 != mean \n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1e9d282-dd4f-4268-8651-4d13114f8240",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
13 changes: 13 additions & 0 deletions neuralforecast/common/_base_model.py
@@ -109,6 +109,18 @@ def __init__(
self.optimizer = optimizer
self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}

# lr scheduler
if lr_scheduler is not None and not issubclass(
lr_scheduler, torch.optim.lr_scheduler.LRScheduler
):
raise TypeError(
"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler"
)
self.lr_scheduler = lr_scheduler
self.lr_scheduler_kwargs = (
lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
)

# customized by set_configure_optimizers()
self.config_optimizers = None

@@ -430,6 +442,7 @@ def set_configure_optimizers(
"strict": strict,
"name": name,
}

if scheduler is not None and optimizer is not None:
if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler):
raise TypeError(
You are viewing a condensed version of this merge commit.
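For reference, the two validation paths added in this file accept different things, and the distinction is easy to trip over. A sketch of both failure modes, using the error messages introduced in this diff (arguments are illustrative):

```python
import torch
from neuralforecast.models import NHITS

# Constructor path: lr_scheduler must be a *subclass* of
# torch.optim.lr_scheduler.LRScheduler.
try:
    NHITS(h=12, input_size=24, lr_scheduler=torch.nn.Module)
except TypeError as e:
    print(e)  # lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler

# set_configure_optimizers() path: scheduler must be an *instance*.
model = NHITS(h=12, input_size=24)
optimizer = torch.optim.Adam(model.parameters())
try:
    model.set_configure_optimizers(
        optimizer=optimizer,
        scheduler=torch.optim.lr_scheduler.StepLR,  # class, not instance
    )
except TypeError as e:
    print(e)  # scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler
```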