Commit 30a7ad3

Tkurth/mplamb fixed (#1684)
1 parent 2d8302a commit 30a7ad3

File tree

1 file changed: +8, -5 lines

apex/optimizers/fused_mixed_precision_lamb.py

Lines changed: 8 additions & 5 deletions
@@ -12,22 +12,25 @@ def __init__(self, params, lr=1e-3, step=0, bias_correction=True,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False,
                  reduced_precision_dtype=None):
+
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
-
-        # The learning rate (lr) and optimizer step (step) should be located on device
-        # in order to faciliated device sync free execution
+
+        # init defaults
         defaults = dict(lr=torch.tensor(lr, dtype=torch.float32),
                         step=torch.tensor([step], dtype=torch.int),
                         bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
-        tensor_state = ['lr', 'step']
+
+        # init base module
         super(FusedMixedPrecisionLamb, self).__init__(params, defaults)
 
+        # The learning rate (lr) and optimizer step (step) should be located on device
+        # in order to faciliated device sync free execution
         device = self.param_groups[0]['params'][0].device
-
+        tensor_state = ['lr', 'step']
         for idx,group in enumerate(self.param_groups):
             for item in tensor_state:
                 self.param_groups[idx][item] = group[item].to(device=device)
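
For context, a minimal PyTorch sketch (not part of this commit) of the pattern the comments in the diff describe: keeping scalar optimizer state such as lr and step as device tensors, so updates can run without pulling values back to the host. The variable names and values below are illustrative only.

import torch

# Illustrative only: mirrors the idea of device-resident lr/step state.
# Assumes CUDA if available, otherwise falls back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lr = torch.tensor(1e-3, dtype=torch.float32).to(device=device)   # learning rate kept as a tensor
step = torch.tensor([0], dtype=torch.int).to(device=device)      # step counter kept as a tensor

param = torch.zeros(4, device=device)
grad = torch.ones(4, device=device)

for _ in range(3):
    step += 1            # in-place increment stays on the device
    param -= lr * grad   # lr is consumed as a tensor, so no .item() call (and no forced sync) is needed

print(step.item(), param)

Because lr and step never have to be read back as Python scalars inside the update loop, the host does not block on the GPU each step; this is the "device sync free execution" the moved comment refers to.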
