diff --git a/op_builder/npu/fused_adam.py b/op_builder/npu/fused_adam.py index fc1bc83c7cc7..d32103db7055 100644 --- a/op_builder/npu/fused_adam.py +++ b/op_builder/npu/fused_adam.py @@ -16,8 +16,8 @@ class NPUFusedAdam: @staticmethod def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, bias_correction, weight_decay, *args): - bias_correction1 = beta1**step - bias_correction2 = beta2**step + bias_correction1 = beta1**(step - 1) + bias_correction2 = beta2**(step - 1) # iteration group['params'] for i in range(len(tensor_lists[0])):