Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient clipping #5150

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
max_norm = torch.tensor([float(max_norm)], device=parameters[0].device)
clip_coef = max_norm / (total_norm + 1e-6)
tmp_tensor = torch.tensor([1.0], device=parameters[0].device)
clip_coef = torch.max(tmp_tensor, clip_coef)
clip_coef = torch.min(tmp_tensor, clip_coef)
for p in parameters:
p.grad.data.mul_(clip_coef)
return total_norm
Expand Down
25 changes: 23 additions & 2 deletions tests/unit/runtime/test_runtime_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def test_call_to_str():
assert c2s('hello', 1138, val=3) == 'hello(1138, val=3)'


class TestClibGradNorm(DistributedTest):
class TestClipGradNorm(DistributedTest):
world_size = 2

def test(self):
def test_gather(self):
param1 = torch.nn.Parameter(torch.Tensor([0]))
param1.grad = torch.Tensor([1])
param2 = torch.nn.Parameter(torch.Tensor([0]))
Expand All @@ -50,6 +50,27 @@ def test(self):

assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1"

def test_clipped_val(self):
max_norm = 0.1

def test_params():
param1 = torch.nn.Parameter(torch.Tensor([0]))
param1.grad = torch.Tensor([1])
param2 = torch.nn.Parameter(torch.Tensor([0]))
param2.grad = torch.Tensor([1])
return [param1, param2]

# This assumes gradients are same on all the ranks and doesn't consider multiple ranks
params_expected = test_params()
torch.nn.utils.clip_grad_norm_(params_expected, max_norm)

params_actual = test_params()
ds_utils.clip_grad_norm_(params_actual, max_norm=max_norm)

# This can be allclose
assert torch.equal(params_expected[0].grad, params_actual[0].grad)
assert torch.equal(params_expected[1].grad, params_actual[1].grad)


@pytest.mark.parametrize("check_using_norm", [(False), (True)])
class TestCheckOverflow(DistributedTest):
Expand Down
Loading