@@ -12,22 +12,25 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
                  reduced_precision_dtype=None):
+
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
-
-        # The learning rate (lr) and optimizer step (step) should be located on device
-        # in order to faciliated device sync free execution
+
+        # init defaults
         defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
                         step=torch.tensor([step], dtype=torch.int),
                         bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
-        tensor_state = ['lr', 'step']
+
+        # init base module
         super(FusedMixedPrecisionLamb, self).__init__(params, defaults)

+        # The learning rate (lr) and optimizer step (step) should be located on device
+        # in order to facilitate device-sync-free execution
         device = self.param_groups[0]['params'][0].device
-
+        tensor_state = ['lr', 'step']
         for idx, group in enumerate(self.param_groups):
             for item in tensor_state:
                 self.param_groups[idx][item] = group[item].to(device=device)
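
The point of this reordering is that lr and step end up as tensors resident on the parameters' device, so per-step bookkeeping can run without a host-device synchronization. A minimal sketch of the same pattern, assuming a hypothetical ToyOptimizer rather than the real FusedMixedPrecisionLamb:

import torch


class ToyOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer sketching device-resident lr/step state."""

    def __init__(self, params, lr=1e-3, step=0):
        # Keep lr and step as tensors so they can later be moved onto the
        # parameters' device and updated without a host-device sync.
        defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
                        step=torch.tensor([step], dtype=torch.int))
        super().__init__(params, defaults)

        # Only after the base class has built param_groups do we know the
        # parameters' device; relocate the tensor state there (as in the diff).
        device = self.param_groups[0]['params'][0].device
        for group in self.param_groups:
            for key in ('lr', 'step'):
                group[key] = group[key].to(device=device)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            group['step'] += 1  # in-place on device, no .item() round-trip
            for p in group['params']:
                if p.grad is not None:
                    p.sub_(group['lr'] * p.grad)  # toy SGD-style update

Because both lr and step live on the same device as the parameters, nothing in step() needs to read a scalar back to the host, which is what keeps the execution sync-free.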