@@ -1,10 +1,17 @@
+from collections import defaultdict
 from typing import Callable, Dict
 
 import torch
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
+from pytorch_optimizer.base.types import (
+    CLOSURE,
+    DEFAULTS,
+    LOSS,
+    OPTIMIZER_INSTANCE_OR_CLASS,
+    STATE,
+)
 
 
 class OrthoGrad(BaseOptimizer):
@@ -20,25 +27,29 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30
 
+        self.state: STATE = defaultdict(dict)
+
         if isinstance(optimizer, Optimizer):
             self.optimizer = optimizer
-        elif 'params' in kwargs:
-            params = kwargs.pop('params')
+        elif "params" in kwargs:
+            params = kwargs.pop("params")
             self.optimizer = optimizer(params, **kwargs)
         else:
-            raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
+            raise ValueError(
+                "Need to pass `params` when you pass the torch.optim.Optimizer instance."
+            )
 
         self.defaults: DEFAULTS = self.optimizer.defaults
 
     def __str__(self) -> str:
-        return 'OrthoGrad'
+        return "OrthoGrad"
 
     @property
     def param_groups(self):
         return self.optimizer.param_groups
 
     def __getstate__(self):
-        return {'optimizer': self.optimizer}
+        return {"optimizer": self.optimizer}
 
     @torch.no_grad()
     def reset(self):
@@ -55,12 +66,14 @@ def orthogonalize_gradients(self, params) -> None:
 
             proj = torch.dot(w, g).div_(torch.dot(w, w).add_(self.eps))
             g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
-            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(self.eps)))
+            g_ortho_scaled = g_ortho.mul_(
+                g.norm(2).div_(g_ortho.norm(2).add_(self.eps))
+            )
 
             p.grad.copy_(g_ortho_scaled.view_as(p.grad))
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         for group in self.param_groups:
-            self.orthogonalize_gradients(group['params'])
+            self.orthogonalize_gradients(group["params"])
         return self.optimizer.step(closure)
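For context, `orthogonalize_gradients` computes a projection coefficient proj = dot(w, g) / (dot(w, w) + eps) per parameter, subtracts proj * w from the flattened gradient, and rescales the result back to the original gradient norm before the wrapped optimizer steps. Below is a minimal usage sketch of the two construction paths `__init__` accepts (wrapping an optimizer instance, or passing an optimizer class together with `params`); it assumes `OrthoGrad` is re-exported at the package level, which may differ in your installed version.

import torch

from pytorch_optimizer import OrthoGrad  # assumed package-level export

model = torch.nn.Linear(4, 2)

# Path 1: wrap an already-constructed torch.optim.Optimizer instance.
optimizer = OrthoGrad(torch.optim.SGD(model.parameters(), lr=1e-2))

# Path 2: pass the optimizer class and supply `params` (plus any other kwargs);
# __init__ pops `params` and instantiates the class itself.
optimizer = OrthoGrad(torch.optim.SGD, params=model.parameters(), lr=1e-2)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()  # gradients are orthogonalized in place, then the wrapped SGD steps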