Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add example of modifying the default configure_optimizers() behavior (use of ReduceLROnPlateau scheduler) #1015

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bf2e11d
Review: Add warning if users pass only optimizer_kwargs but not optim…
JQGoh May 23, 2024
7d83799
option to modify configure_optimizers
JQGoh May 23, 2024
ce6848c
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Aug 2, 2024
3539eb3
Add example on ReduceLROnPlateau
JQGoh Aug 2, 2024
c0a7eb8
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Dec 11, 2024
2e3d7c6
Remove old interface and deprecate the arguments
JQGoh Dec 11, 2024
eb17a4d
Fix test
JQGoh Dec 11, 2024
f5a4f62
Correction to the path in contributing.md note
JQGoh Dec 18, 2024
829d80a
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Dec 21, 2024
3bab7c0
Review: Allow users to specifify the configure_optimizers() for
JQGoh Dec 22, 2024
c5f20a1
Add arguments and doc string
JQGoh Dec 22, 2024
a817907
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Jan 21, 2025
a5ca43e
Merge branch 'main' into feat/modify-config-optimizers
JQGoh Jan 27, 2025
5fb6331
Revert "Add arguments and doc string"
JQGoh Jan 28, 2025
3ca378a
Revert "Review: Allow users to specifify the configure_optimizers() for"
JQGoh Jan 28, 2025
5c8e1c2
Revert "Fix test"
JQGoh Jan 28, 2025
cc317cf
Revert "Remove old interface and deprecate the arguments"
JQGoh Jan 28, 2025
0266daf
Revert "Add example on ReduceLROnPlateau"
JQGoh Jan 28, 2025
5765c20
Revert "option to modify configure_optimizers"
JQGoh Jan 28, 2025
0d07a7a
Omit unnecessary changes
JQGoh Jan 28, 2025
684184f
Add documentation on customizing configure_optimizers()
JQGoh Jan 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Revert "Add example on ReduceLROnPlateau"
This reverts commit 3539eb3.
JQGoh committed Jan 28, 2025
commit 0266daff57d61fac2d2dc27629a02a231fb53912
64 changes: 0 additions & 64 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
@@ -3359,70 +3359,6 @@
")\n",
"assert all([col in cv2.columns for col in ['NHITS-lo-30', 'NHITS-hi-30']])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f0ca124",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test customized lr_scheduler behavior such that the user defined lr_scheduler result should differ from default\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 2}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" nf.fit(AirPassengersPanel_train)\n",
" default_predict = nf.predict()\n",
" mean = default_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" # calling set_configure_optimizers() shall modify the default behavior of configure_optimizers()\n",
" optimizer = torch.optim.Adadelta(params=models[0].parameters(), rho=0.45)\n",
" scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer=optimizer, factor=0.78)\n",
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
"\n",
" )\n",
" nf2 = NeuralForecast(models=models, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
" customized_predict = nf2.predict()\n",
" mean2 = customized_predict.loc[:, nf_model.__name__].mean()\n",
" assert mean2 != mean\n",
"\n",
" # test that frequency configured has effect on optimization behavior\n",
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
" frequency=2,\n",
" )\n",
" nf3 = NeuralForecast(models=models, freq='M')\n",
" nf3.fit(AirPassengersPanel_train)\n",
" customized_predict3 = nf3.predict()\n",
" mean3 = customized_predict3.loc[:, nf_model.__name__].mean()\n",
" assert mean3 != mean\n",
"\n",
" # test ReduceLROnPlateau\n",
" scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
" optimizer, mode='min', factor=0.5, patience=2,\n",
" )\n",
" \n",
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
" monitor=\"train_loss\",\n",
" )\n",
" nf4 = NeuralForecast(models=models, freq='M')\n",
" nf4.fit(AirPassengersPanel_train)\n",
" customized_predict4 = nf4.predict()\n",
" mean4 = customized_predict4.loc[:, nf_model.__name__].mean()\n",
" assert mean4 != mean \n"
]
}
],
"metadata": {