support dtype casting in fused adam
Signed-off-by: Shijie Wang <[email protected]>
Wong4j committed Jul 1, 2024
1 parent 086a12f commit aa11601
Showing 5 changed files with 1,041 additions and 768 deletions.
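
The new tests exercise FusedAdam's dtype-casting path: the optimizer is built over FP32 master copies of the weights with extra_param_out=True, the low-precision model parameters are registered under the param group's "extra_params" key, and each step is expected to write the updated master values back into those parameters in their own dtype (bf16 here, or FP8 under fp8_model_init). A minimal sketch of that pattern, mirroring the bf16 test in the diff below; the flag and key names are taken from this diff, and the random gradients are placeholders for a real backward pass:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import MultiheadAttention

# Build a bf16 model; the optimizer will hold FP32 master copies of its weights.
model = MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    layer_number=1,
    params_dtype=torch.bfloat16,
    fuse_qkv_params=True,
).cuda()

model_params = [p for p in model.parameters() if p.requires_grad]
master_params = [p.detach().clone().float() for p in model_params]

# extra_param_out asks FusedAdam to also write the updated values back into
# the registered low-precision "extra_params" (here, the bf16 model weights).
optim = te.optimizers.FusedAdam(master_params, lr=5e-4, extra_param_out=True)
optim.param_groups[0]["extra_params"] = model_params

# Placeholder gradients stand in for a real backward pass.
for p in master_params:
    p.grad = torch.randn_like(p)
optim.step()  # updates the FP32 masters and casts them into the bf16 weights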
101 changes: 88 additions & 13 deletions tests/pytorch/test_fused_optimizer.py
@@ -10,6 +10,9 @@
 from torch import nn
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.attention import MultiheadAttention
+from transformer_engine.pytorch import fp8_autocast, fp8_model_init


 class TestFusedOptimizer(unittest.TestCase):
@@ -117,9 +120,7 @@ def test_multi_params(self):
         tensors = []
         for size in sizes:
             tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
-        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
-            tensors, self.options
-        )
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)

         for i in range(self.iters):
             self.gen_grad(ref_param, tst_param)
@@ -139,9 +140,7 @@ def test_adam_option(self):
         }

         tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
-        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
-            [tensor], adam_option
-        )
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)

         for i in range(self.iters):
             self.gen_grad(ref_param, tst_param)
@@ -161,9 +160,7 @@ def test_frozen_model(self):
         }

         tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
-        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
-            [tensor], adam_option
-        )
+        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)

         # Add an empty param group which may occur for pipeline parallel p-tuning
         tst_optim.add_param_group({"params": []})
@@ -175,10 +172,88 @@

         torch.testing.assert_close(ref_param, tst_param)

+    def test_bf16_model_weight_cast(self):
+        dtype = torch.bfloat16
+        model = MultiheadAttention(
+            hidden_size=1024,
+            num_attention_heads=16,
+            layer_number=1,
+            params_dtype=dtype,
+            fuse_qkv_params=True,
+        ).cuda()
+        ref_params = []
+        master_params = []
+        model_params = []
+        for p in model.parameters():
+            if p.requires_grad:
+                ref_params.append(p.detach().clone().float())
+                master_params.append(p.detach().clone().float())
+                model_params.append(p)
+        options = {
+            "lr": 5e-4,
+            "betas": (0.9, 0.999),
+            "eps": 1e-08,
+            "weight_decay": 0,
+            "amsgrad": False,
+        }
+        ref_optim = torch.optim.Adam(ref_params, **options)
+        tst_optim = te.optimizers.FusedAdam(master_params, extra_param_out=True, **options)
+        tst_optim.param_groups[0]["extra_params"] = model_params
+
+        for i in range(self.iters):
+            self.gen_grad(ref_params, master_params)
+            ref_optim.step()
+            tst_optim.step()
+            torch.testing.assert_close(ref_params, master_params)
+            model_params_to_fp32 = [p.float() for p in model_params]
+            torch.testing.assert_close(
+                ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True
+            )
+
+    def test_fp8_model_weight_cast(self):
+        dtype = torch.bfloat16
+        with fp8_model_init(enabled=True):
+            model = MultiheadAttention(
+                hidden_size=1024,
+                num_attention_heads=16,
+                layer_number=1,
+                params_dtype=dtype,
+                fuse_qkv_params=True,
+            ).cuda()
+        ref_params = []
+        master_params = []
+        model_params = []
+        for p in model.parameters():
+            if p.requires_grad:
+                ref_params.append(p.detach().clone().float())
+                master_params.append(p.detach().clone().float())
+                model_params.append(p)
+        options = {
+            "lr": 5e-4,
+            "betas": (0.9, 0.999),
+            "eps": 1e-08,
+            "weight_decay": 0,
+            "amsgrad": False,
+        }
+        ref_optim = torch.optim.Adam(ref_params, **options)
+        tst_optim = te.optimizers.FusedAdam(master_params, extra_param_out=True, **options)
+        tst_optim.param_groups[0]["extra_params"] = model_params
+
+        for i in range(self.iters):
+            self.gen_grad(ref_params, master_params)
+            ref_optim.step()
+            tst_optim.step()
+            torch.testing.assert_close(ref_params, master_params)
+            model_params_to_fp32 = [p.float() for p in model_params]
+            torch.testing.assert_close(
+                ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True
+            )
+

 class TestFusedSGD(TestFusedOptimizer):
     def __init__(self, *args, **kwargs):
         super(TestFusedSGD, self).__init__(*args, **kwargs)
-        self.options = {"lr": .25, "momentum": .125}
+        self.options = {"lr": 0.25, "momentum": 0.125}
         self.ref_optim = torch.optim.SGD
         self.fused_optim = te.optimizers.FusedSGD
@@ -188,7 +263,7 @@ def test_float(self):
     def test_half(self):
         self.gen_single_type_test(param_type=torch.float16)

-    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
     def test_multi_device(self):
         devices = ("cuda:0", "cuda:1")
         for current_dev, tensor_dev in product(devices, devices):
@@ -452,8 +527,8 @@ def testNative(self):

     @largeTensorTest("60GB", "cuda")
     def testLargeTensor(self):
-        t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
-        t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
+        t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
+        t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
         grad = torch.randn_like(t)
         t.grad = grad
         t2.grad = grad