Skip to content

Commit

Permalink
Merge pull request #354 from kozistr/fix/schedulefreeradam-optimizer
Browse files Browse the repository at this point in the history
[Fix] `bias_correction2` in ScheduleFreeRAdam optimizer
  • Loading branch information
kozistr authored Feb 23, 2025
2 parents c09d18b + ad60eb0 commit 3d20627
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 4 deletions.
6 changes: 6 additions & 0 deletions docs/changelogs/v3.4.3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
### Change Log


### Fix

* Compute `bias_correction2` in the ScheduleFreeRAdam optimizer with `debias` instead of `debias_beta`. (#354)
Binary file modified docs/visualizations/rastrigin_ScheduleFreeRAdam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/visualizations/rosenbrock_ScheduleFreeRAdam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 5 additions & 2 deletions examples/visualize_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
'dadaptlion': {'lr': hp.uniform('lr', 0, 10)},
'padam': {'lr': hp.uniform('lr', 0, 10)},
'dadaptadam': {'lr': hp.uniform('lr', 0, 10)},
'adahessian': {'lr': hp.uniform('lr', 0, 800)}, # Wider range for second-order optimizers
'sophiah': {'lr': hp.uniform('lr', 0, 60)}, # Wider range for second-order optimizers
'adahessian': {'lr': hp.uniform('lr', 0, 800)},
'sophiah': {'lr': hp.uniform('lr', 0, 60)},
'pid': {
'lr': hp.uniform('lr', 0, 0.5),
'derivative': hp.quniform('derivative', 2, 14, 0.5),
Expand Down Expand Up @@ -86,6 +86,9 @@
'lr': hp.uniform('lr', 0, 3),
'momentum': hp.quniform('momentum', 0, 0.99, 0.01),
},
'schedulefreeradam': {
'lr': hp.uniform('lr', 1, 10),
},
'kron': {
'lr': hp.uniform('lr', 0, 0.8),
'momentum': hp.quniform('momentum', 0, 0.99, 0.01),
Expand Down
2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/schedulefree.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction2: float = self.debias_beta(beta2, group['step'])
bias_correction2: float = self.debias(beta2, group['step'])

lr, n_sma = self.get_rectify_step_size(
is_rectify=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeRAdam, {'lr': 1e0}, 20),
(ScheduleFreeRAdam, {'lr': 1e2, 'weight_decay': 1e-3}, 20),
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(GrokFastAdamW, {'lr': 5e0, 'weight_decay': 1e-3, 'grokfast_after_step': 1}, 5),
(Kate, {'lr': 5e-2}, 10),
Expand Down

0 comments on commit 3d20627

Please sign in to comment.