diff --git a/cneuromax/fitting/deeplearning/litmodule/store.py b/cneuromax/fitting/deeplearning/litmodule/store.py index c3c380a0..c07cf5bc 100644 --- a/cneuromax/fitting/deeplearning/litmodule/store.py +++ b/cneuromax/fitting/deeplearning/litmodule/store.py @@ -1,6 +1,7 @@ """:class:`.BaseLitModule` `Hydra `_ config store.""" from hydra_zen import ZenStore +from schedulefree import AdamWScheduleFree from torch.optim import SGD, Adam, AdamW from transformers import ( get_constant_schedule, @@ -46,6 +47,11 @@ def store_basic_optimizer_configs(store: ZenStore) -> None: """ store(pfs_builds(Adam), name="adam", group="litmodule/optimizer") store(pfs_builds(AdamW), name="adamw", group="litmodule/optimizer") + store( + pfs_builds(AdamWScheduleFree), + name="sfadamw", + group="litmodule/optimizer", + ) store(pfs_builds(SGD), name="sgd", group="litmodule/optimizer") diff --git a/pyproject.toml b/pyproject.toml index faafa7b1..a3d58d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ # OPTIONAL for cneuromax/fitting/deeplearning/ "torchaudio==2.5.1+cu124", # Tensor manipulation on audio data "torchvision==0.20.1+cu124", # Tensor manipulation on vision data + "schedulefree==1.4", # Automatic learning rate scheduler "transformers==4.47.0", # Pre-trained models published on Hugging Face "diffusers==0.31.0", # Diffusion models published on Hugging Face "timm==1.0.12", # Image models