From 0f84f8f39f021fe680616d94808d8b6464270835 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 28 Aug 2024 17:13:52 -0700
Subject: [PATCH] add regenerative regularization and cite

---
 README.md                  |  9 +++++++++
 q_transformer/optimizer.py | 11 ++++-------
 q_transformer/q_learner.py |  2 ++
 setup.py                   |  3 ++-
 4 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 3c6e0ef..21a638a 100644
--- a/README.md
+++ b/README.md
@@ -151,3 +151,12 @@ actions = model.get_optimal_actions(video, instructions)
     year = {2022}
 }
 ```
+
+```bibtex
+@inproceedings{Kumar2023MaintainingPI,
+    title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
+    author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
+    year = {2023},
+    url = {https://api.semanticscholar.org/CorpusID:261076021}
+}
+```
diff --git a/q_transformer/optimizer.py b/q_transformer/optimizer.py
index b50e24a..463f2bc 100644
--- a/q_transformer/optimizer.py
+++ b/q_transformer/optimizer.py
@@ -1,4 +1,4 @@
-from torch.optim import AdamW, Adam
+from adam_atan2_pytorch import AdamAtan2
 
 def separate_weight_decayable_params(params):
     wd_params, no_wd_params = [], []
@@ -10,9 +10,9 @@ def separate_weight_decayable_params(params):
 def get_adam_optimizer(
     params,
     lr = 1e-4,
-    wd = 1e-2,
+    wd = 0,
     betas = (0.9, 0.99),
-    eps = 1e-8,
+    regen_reg_rate = 1e-2,
     filter_by_requires_grad = False,
     group_wd_params = True
 ):
@@ -29,7 +29,4 @@ def get_adam_optimizer(
             {'params': no_wd_params, 'weight_decay': 0},
         ]
 
-    if not has_wd:
-        return Adam(params, lr = lr, betas = betas, eps = eps)
-
-    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    return AdamAtan2(params, lr = lr, weight_decay = wd, betas = betas, regen_reg_rate = regen_reg_rate)
diff --git a/q_transformer/q_learner.py b/q_transformer/q_learner.py
index e4a44db..b24970f 100644
--- a/q_transformer/q_learner.py
+++ b/q_transformer/q_learner.py
@@ -93,6 +93,7 @@ def __init__(
         grad_accum_every: int = 1,
         monte_carlo_return: float | None = None,
         weight_decay: float = 0.,
+        regen_reg_rate: float = 1e-3,
         accelerator: Accelerator | None = None,
         accelerator_kwargs: dict = dict(),
         dataloader_kwargs: dict = dict(
@@ -143,6 +144,7 @@ def __init__(
             model.parameters(),
             lr = learning_rate,
             wd = weight_decay,
+            regen_reg_rate = regen_reg_rate,
             **optimizer_kwargs
         )
 
diff --git a/setup.py b/setup.py
index 100d5ee..7da4c1d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.1.17',
+  version = '0.2.0',
   license='MIT',
   description = 'Q-Transformer',
   author = 'Phil Wang',
@@ -19,6 +19,7 @@
   ],
   install_requires=[
     'accelerate',
+    'adam-atan2-pytorch>=0.0.12',
     'beartype',
     'classifier-free-guidance-pytorch>=0.6.10',
     'einops>=0.8.0',
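
Not part of the patch, but for context: a minimal usage sketch of the new `regen_reg_rate` knob exposed by the patched `q_transformer/optimizer.py`, with a toy `nn.Linear` standing in for the real model purely for illustration.

```python
# minimal sketch (not part of the patch), assuming the patched q_transformer/optimizer.py
# a toy nn.Linear stands in for the actual model

import torch
from torch import nn

from q_transformer.optimizer import get_adam_optimizer

model = nn.Linear(16, 4)  # placeholder module, purely for illustration

optimizer = get_adam_optimizer(
    model.parameters(),
    lr = 1e-4,
    wd = 0.,                # plain weight decay now defaults to 0
    regen_reg_rate = 1e-2   # regenerative regularization rate, passed through to AdamAtan2
)

# one ordinary training step
loss = model(torch.randn(2, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

A positive `regen_reg_rate` regularizes parameters back toward their initial values (Kumar et al. 2023) rather than toward zero, which presumably is why the default weight decay drops to 0 in this patch.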