Skip to content

Commit 83204c0

fix state management bug
1 parent 55c3553 commit 83204c0

File tree

1 file changed: +21 -8 lines changed


pytorch_optimizer/optimizer/orthograd.py

@@ -1,10 +1,17 @@
+from collections import defaultdict
 from typing import Callable, Dict
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
+from pytorch_optimizer.base.types import (
+    CLOSURE,
+    DEFAULTS,
+    LOSS,
+    OPTIMIZER_INSTANCE_OR_CLASS,
+    STATE,
+)
 
 
 class OrthoGrad(BaseOptimizer):
@@ -20,25 +27,29 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30
 
+        self.state: STATE = defaultdict(dict)
+
         if isinstance(optimizer, Optimizer):
             self.optimizer = optimizer
-        elif 'params' in kwargs:
-            params = kwargs.pop('params')
+        elif "params" in kwargs:
+            params = kwargs.pop("params")
             self.optimizer = optimizer(params, **kwargs)
         else:
-            raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
+            raise ValueError(
+                "Need to pass `params` when you pass the torch.optim.Optimizer instance."
+            )
 
         self.defaults: DEFAULTS = self.optimizer.defaults
 
     def __str__(self) -> str:
-        return 'OrthoGrad'
+        return "OrthoGrad"
 
     @property
     def param_groups(self):
         return self.optimizer.param_groups
 
     def __getstate__(self):
-        return {'optimizer': self.optimizer}
+        return {"optimizer": self.optimizer}
 
     @torch.no_grad()
     def reset(self):
@@ -55,12 +66,14 @@ def orthogonalize_gradients(self, params) -> None:
 
             proj = torch.dot(w, g).div_(torch.dot(w, w).add_(self.eps))
             g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
-            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(self.eps)))
+            g_ortho_scaled = g_ortho.mul_(
+                g.norm(2).div_(g_ortho.norm(2).add_(self.eps))
+            )
 
             p.grad.copy_(g_ortho_scaled.view_as(p.grad))
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         for group in self.param_groups:
-            self.orthogonalize_gradients(group['params'])
+            self.orthogonalize_gradients(group["params"])
         return self.optimizer.step(closure)
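For reference, the last hunk is the core of the method: for each parameter, the component of the gradient parallel to the weights is projected out, and the result is rescaled back to the original gradient norm so the effective step size is preserved. A minimal standalone re-derivation of that hunk for a single flattened parameter (the helper name `orthograd_once` is hypothetical, not part of the repo, and the float32 cast from the diff is omitted for brevity):

import torch

def orthograd_once(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    """Project the gradient g off the weight direction w, keep its norm."""
    proj = torch.dot(w, g) / (torch.dot(w, w) + eps)   # (w . g) / (w . w + eps)
    g_ortho = g - proj * w                             # component orthogonal to w
    return g_ortho * (g.norm(2) / (g_ortho.norm(2) + eps))

w = torch.tensor([1.0, 0.0])
g = torch.tensor([1.0, 1.0])
out = orthograd_once(w, g)
print(out)                # ~[0., 1.4142]: parallel part removed, norm preserved
print(torch.dot(out, w))  # ~0: the new gradient is orthogonal to the weights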
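The commit message points at the new `self.state: STATE = defaultdict(dict)` line: `torch.optim.Optimizer` initializes its own state the same way, so the wrapper presumably gained the attribute so that code which inspects `optimizer.state` (LR schedulers, checkpoint helpers) does not fail on an `OrthoGrad` instance. A small usage sketch, assuming the package's usual top-level export of `OrthoGrad`:

import torch
from pytorch_optimizer import OrthoGrad

model = torch.nn.Linear(10, 2)

# Wrap an existing optimizer instance, or pass the class plus `params` kwargs.
optimizer = OrthoGrad(torch.optim.SGD(model.parameters(), lr=1e-2))

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()  # gradients are orthogonalized, then the wrapped SGD steps

# After this commit, poking at `optimizer.state` works instead of raising
# AttributeError; it starts out as an empty defaultdict, as in torch.optim.
print(len(optimizer.state))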
