
Commit

add regenerative regularization and cite
lucidrains committed Aug 29, 2024
1 parent 0b5c5ab commit 0f84f8f
Showing 4 changed files with 17 additions and 8 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -151,3 +151,12 @@ actions = model.get_optimal_actions(video, instructions)
year = {2022}
}
```

```bibtex
@inproceedings{Kumar2023MaintainingPI,
title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:261076021}
}
```
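
For context, the regenerative regularization cited above penalizes drift of the weights away from their values at initialization, rather than decaying them toward zero as plain L2 weight decay does; the paper argues this helps maintain plasticity in continual learning. A minimal sketch of the idea (illustrative names only, not code from this repo):

```python
import torch

def regen_reg_loss(params, init_params, regen_reg_rate = 1e-2):
    # regenerative regularization: penalize the squared distance of each
    # parameter from its initial value, instead of its distance from zero
    loss = torch.tensor(0.)
    for param, init in zip(params, init_params):
        loss = loss + (param - init.detach()).pow(2).sum()
    return 0.5 * regen_reg_rate * loss

# init_params would be captured once, right after the model is created, e.g.
# init_params = [p.detach().clone() for p in model.parameters()]
```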
11 changes: 4 additions & 7 deletions q_transformer/optimizer.py
```diff
@@ -1,4 +1,4 @@
-from torch.optim import AdamW, Adam
+from adam_atan2_pytorch import AdamAtan2
 
 def separate_weight_decayable_params(params):
     wd_params, no_wd_params = [], []
@@ -10,9 +10,9 @@ def separate_weight_decayable_params(params):
 def get_adam_optimizer(
     params,
     lr = 1e-4,
-    wd = 1e-2,
+    wd = 0,
     betas = (0.9, 0.99),
-    eps = 1e-8,
+    regen_reg_rate = 1e-2,
     filter_by_requires_grad = False,
     group_wd_params = True
 ):
@@ -29,7 +29,4 @@ def get_adam_optimizer(
         {'params': no_wd_params, 'weight_decay': 0},
     ]
 
-    if not has_wd:
-        return Adam(params, lr = lr, betas = betas, eps = eps)
-
-    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    return AdamAtan2(params, lr = lr, weight_decay = wd, betas = betas, regen_reg_rate = regen_reg_rate)
```
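
With this change, `get_adam_optimizer` always returns `AdamAtan2`, trading the `eps` argument for `regen_reg_rate` (and requiring the `adam-atan2-pytorch` dependency added in `setup.py` below). A hypothetical call site, with a plain `nn.Linear` standing in for the actual model:

```python
import torch.nn as nn
from q_transformer.optimizer import get_adam_optimizer

model = nn.Linear(512, 512)  # placeholder module, not the real Q-Transformer

optim = get_adam_optimizer(
    model.parameters(),
    lr = 1e-4,
    wd = 0.,                 # weight decay now defaults to 0
    regen_reg_rate = 1e-2    # strength of the pull back toward the initial weights
)
```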
2 changes: 2 additions & 0 deletions q_transformer/q_learner.py
```diff
@@ -93,6 +93,7 @@ def __init__(
         grad_accum_every: int = 1,
         monte_carlo_return: float | None = None,
         weight_decay: float = 0.,
+        regen_reg_rate: float = 1e-3,
         accelerator: Accelerator | None = None,
         accelerator_kwargs: dict = dict(),
         dataloader_kwargs: dict = dict(
@@ -143,6 +144,7 @@ def __init__(
             model.parameters(),
             lr = learning_rate,
             wd = weight_decay,
+            regen_reg_rate = regen_reg_rate,
             **optimizer_kwargs
         )
 
```
3 changes: 2 additions & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.1.17',
+  version = '0.2.0',
   license='MIT',
   description = 'Q-Transformer',
   author = 'Phil Wang',
@@ -19,6 +19,7 @@
   ],
   install_requires=[
     'accelerate',
+    'adam-atan2-pytorch>=0.0.12',
     'beartype',
     'classifier-free-guidance-pytorch>=0.6.10',
     'einops>=0.8.0',
```
