
Commit

Minor changes based on review comments.
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Kunlun Li <[email protected]>
2 people authored and kunlunl committed Oct 31, 2024
1 parent d05cc42 commit 051e94b
Showing 3 changed files with 148 additions and 78 deletions.
149 changes: 107 additions & 42 deletions tests/pytorch/test_fused_optimizer.py
@@ -175,9 +175,21 @@ def test_frozen_model(self):

torch.testing.assert_close(ref_param, tst_param)

def gen_precision_aware_test(self, use_fp8_params, param_dtype, use_master_weights,
master_weight_dtype, grad_dtype, exp_avg_dtype, exp_avg_sq_dtype,
model_rtol, model_atol, master_rtol=None, master_atol=None):
def gen_precision_aware_test(
self,
use_fp8_params,
param_dtype,
use_master_weights,
master_weight_dtype,
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
model_rtol=None,
model_atol=None,
master_rtol=None,
master_atol=None,
skip_assert=False,
):
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
@@ -202,8 +214,8 @@ def gen_precision_aware_test(self, use_fp8_params, param_dtype, use_master_weigh
model_params.append(p)

options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"lr": 1,
"betas": (0.1, 0.25),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
@@ -229,14 +241,23 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
if not skip_assert:
torch.testing.assert_close(
ref_params,
master_weights_to_fp32,
rtol=master_rtol,
atol=master_atol,
equal_nan=True,
)
ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
if not skip_assert:
torch.testing.assert_close(
ref_params, master_weights_to_fp32, rtol=master_rtol, atol=master_atol,
ref_params_to_model_dtype,
model_params,
rtol=model_rtol,
atol=model_atol,
equal_nan=True,
)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
ref_params, model_params_to_fp32, rtol=model_rtol, atol=model_atol, equal_nan=True,
)

for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)
@@ -258,66 +279,110 @@ def test_one_iteration(ref_optimizer, tst_optimizer):

def test_fp32_no_master(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=False,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32,
model_rtol=2e-3, model_atol=2e-3,
use_fp8_params=False,
param_dtype=torch.float32,
use_master_weights=False,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1.3e-6, master_atol=1e-5,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.half, grad_dtype=torch.float32,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1e-3, master_atol=1e-3,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.half,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.bfloat16,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1.3e-6, master_atol=1e-5,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.bfloat16,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.half, exp_avg_sq_dtype=torch.float32,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1.3e-6, master_atol=1e-5,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.half,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.uint8, exp_avg_sq_dtype=torch.float32,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1e-3, master_atol=1e-3,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.uint8,
exp_avg_sq_dtype=torch.float32,
master_rtol=1e-2,
master_atol=1e-2,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.half,
model_rtol=1e-3, model_atol=1e-3, master_rtol=1.3e-6, master_atol=1e-5,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.half,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False, param_dtype=torch.bfloat16, use_master_weights=True,
master_weight_dtype=torch.float32, grad_dtype=torch.float32,
exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.uint8,
model_rtol=5e-2, model_atol=5e-2, master_rtol=5e-2, master_atol=5e-2,
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.uint8,
skip_assert=True,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
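For readers unfamiliar with these precision-aware tests, the sketch below illustrates the comparison pattern they rely on, using the hyperparameters from the options dict above (lr=1, betas=(0.1, 0.25)). It is a minimal stand-in, not the actual fixture: the tensor shape, seed, and tolerances are illustrative, the import path for FusedAdam is assumed to be transformer_engine.pytorch.optimizers, and it assumes FusedAdam materializes its fp16 master weights from the bf16 params on the first step. An fp32 torch.optim.Adam reference is stepped next to a FusedAdam that keeps fp16 master weights, and get_unscaled_state() recovers the fp32 master values for comparison.

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed import path

torch.manual_seed(0)

# Model weights in bf16; the fp32 reference starts from exactly the same values.
tst_params = [torch.randn(64, 64, device="cuda").bfloat16().requires_grad_()]
ref_params = [p.detach().clone().float().requires_grad_() for p in tst_params]

opts = dict(lr=1.0, betas=(0.1, 0.25), eps=1e-8, weight_decay=0)
ref_optim = torch.optim.Adam(ref_params, **opts)
tst_optim = FusedAdam(
    tst_params,
    master_weights=True,
    master_weight_dtype=torch.float16,  # non-fp32 master weights are kept with an fp32 scale
    **opts,
)

# Feed both optimizers the same gradient, cast to the model dtype for the test optimizer.
for ref_p, tst_p in zip(ref_params, tst_params):
    grad = torch.randn_like(ref_p)
    ref_p.grad = grad.clone()
    tst_p.grad = grad.to(torch.bfloat16)

ref_optim.step()
tst_optim.step()

# Recover fp32 master weights from their scaled fp16 storage and compare to the reference,
# using the looser tolerances this commit adopts for fp16 master weights.
for ref_p, tst_p in zip(ref_params, tst_params):
    master = tst_optim.get_unscaled_state(tst_p, "master_param")
    torch.testing.assert_close(ref_p.detach(), master, rtol=2e-3, atol=2e-3)

The skip_assert flag added in this commit simply bypasses these assert_close calls for configurations such as FP8 exp_avg_sq, where a fixed tolerance is not meaningful.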
@@ -508,8 +508,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
@@ -531,8 +531,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
69 changes: 37 additions & 32 deletions transformer_engine/pytorch/optimizers/fused_adam.py
@@ -76,11 +76,19 @@ class FusedAdam(torch.optim.Optimizer):
in the optimizer with FP16/BF16 mixed precision training.
(default: False)
master_weight_dtype (torch.dtype, optional): The dtype of master weights.
If master_weights is False, this will be ignored.
If master_weights is False, this will be ignored. It can be one of
[torch.float32, torch.float16]. If it's not torch.float32, the optimizer
will create an FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg.
exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg. It can be
one of [torch.float32, torch.float16, torch.uint8], where torch.uint8
represents FP8. If it's not torch.float32, the optimizer will create
an FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq.
exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq. It
can be one of [torch.float32, torch.float16, torch.uint8], where
torch.uint8 represents FP8. If it's not torch.float32, the optimizer
will create an FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
use_decoupled_grad (bool, optional): Whether to use ".decoupled_grad"
instead of ".grad" for reading gradients. It's useful when the dtypes
@@ -116,11 +124,11 @@ def __init__(
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")

# Add constraints to dtypes of states.
if master_weights and master_weight_dtype not in [torch.float32, torch.half]:
if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
if exp_avg_dtype not in [torch.float32, torch.half, torch.uint8]:
if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
if exp_avg_sq_dtype not in [torch.float32, torch.half, torch.uint8]:
if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")

# Currently, capturable mode only supports fp32 master weights and optimizer states.
@@ -177,7 +185,9 @@ def __init__(
"master_param": self.master_weight_dtype,
}
self.dtype_to_range_map = {
torch.half: torch.full([1], torch.finfo(torch.half).max / 2.0, dtype=torch.float32),
torch.float16: torch.full(
[1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
),
torch.uint8: torch.full([1], 448.0, dtype=torch.float32),
}
self._scales = {}
@@ -245,17 +255,17 @@ def get_unscaled_state(self, param, state_name):
dtype = self.name_to_dtype_map[state_name]
if dtype == torch.uint8:
assert isinstance(state[state_name], Float8Tensor)
return state[state_name].float()
elif dtype == torch.half:
assert state[state_name].dtype == torch.half
unscaled = state[state_name].float()
elif dtype == torch.float16:
assert state[state_name].dtype == torch.float16
unscaled = state[state_name].float()
unscaled.mul_(self._scales[param][state_name])
return unscaled
elif dtype == torch.float32:
assert state[state_name].dtype == torch.float32
return state[state_name]
unscaled = state[state_name]
else:
raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
return unscaled

def set_scaled_state(self, param, state_name, unscaled_state):
"""Set the optimizer state.
@@ -271,27 +281,15 @@ def set_scaled_state(self, param, state_name, unscaled_state):
"""
assert unscaled_state.dtype == torch.float32
state = self.state[param]
if param not in self._scales:
self._scales[param] = {}
scale = self._scales[param]
if state_name not in state:
self._initialize_state(param, state_name, False)

dtype = self.name_to_dtype_map[state_name]
if dtype == torch.uint8:
if state_name not in state:
state[state_name] = Float8Tensor.to_float8(torch.empty_like(param.data).float())
scale[state_name] = torch.ones([1], device=param.device)
self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
elif dtype == torch.half:
if state_name not in state:
state[state_name] = torch.empty_like(param.data).half()
scale[state_name] = torch.ones([1], device=param.device)
if dtype != torch.float32:
scale = self._scales[param]
self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
elif dtype == torch.float32:
if state_name not in state:
state[state_name] = torch.empty_like(param.data).float()
state[state_name].copy_(unscaled_state)
else:
raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
state[state_name].copy_(unscaled_state)

def _initialize_state(self, param, state_name, zero_buffer: bool):
"""Initialize one of the optimizer states according to `state_name`.
@@ -302,12 +300,19 @@ def _initialize_state(self, param, state_name, zero_buffer: bool):
and 'master_param`.
zero_buffer (bool): Whether to initialize the optimizer state with zeros.
"""
buffer = torch.zeros_like(param).float() if zero_buffer else torch.empty_like(param).float()
dtype = self.name_to_dtype_map[state_name]
data = torch.empty_like(param, dtype=dtype)
if zero_buffer:
data.zero_()

if dtype == torch.uint8:
self.state[param][state_name] = Float8Tensor.to_float8(buffer)
self.state[param][state_name] = Float8Tensor(
data=data,
dtype=torch.float32,
fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device),
)
else:
self.state[param][state_name] = buffer.to(dtype)
self.state[param][state_name] = data

# Create scale if necessary.
if dtype != torch.float32:
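The docstring above states that any non-fp32 optimizer state is stored together with an FP32 scalar scaling factor, and get_unscaled_state() multiplies the stored tensor by that scale to recover fp32 values. The optimizer's own scale update (_apply_scale) is not shown in this diff, so the snippet below is only a standalone illustration of the idea, using the same range constants as dtype_to_range_map (half of the fp16 max, and 448 for FP8 E4M3): the scale is derived from the tensor's absolute maximum so the scaled values fit the target dtype's range, and unscaling is a cast back to fp32 followed by a multiply, mirroring the fp16 branch of get_unscaled_state.

import torch

FP16_RANGE = torch.finfo(torch.float16).max / 2.0  # same constant as dtype_to_range_map
FP8_E4M3_RANGE = 448.0                              # max representable magnitude in FP8 E4M3

def scale_to_fp16(unscaled: torch.Tensor):
    """Store an fp32 tensor as (fp16 data, fp32 scale) with data * scale ~= unscaled.

    Illustrative only: the optimizer's real _apply_scale may choose the scale differently.
    """
    amax = unscaled.abs().max().float().clamp_(min=1e-12)
    scale = amax / FP16_RANGE                 # fp32 scalar kept next to the fp16 buffer
    data = (unscaled / scale).to(torch.float16)
    return data, scale

def unscale_from_fp16(data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Mirror of the fp16 branch of get_unscaled_state: cast up, then multiply by the scale."""
    return data.float() * scale

state = torch.randn(1 << 14) * 1e5            # values far outside the plain fp16 range
data, scale = scale_to_fp16(state)
recovered = unscale_from_fp16(data, scale)
max_rel_err = ((recovered - state).abs() / state.abs().clamp(min=1e-6)).max()
print(f"scale={scale.item():.3e}  max relative error={max_rel_err.item():.3e}")
# Relative error stays around fp16 machine epsilon (~5e-4), even though the raw
# values would overflow fp16 without the scale.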
