From db1495b0b36403b5e5919ce5983a39434b17cbb0 Mon Sep 17 00:00:00 2001
From: liangyuwang
Date: Tue, 6 Aug 2024 06:11:26 +0300
Subject: [PATCH] fix communication bugs

---
 tiny_deepspeed/core/zero/ddp/module.py   |  5 ++++-
 tiny_deepspeed/core/zero/zero1/module.py |  6 ++++--
 tiny_deepspeed/core/zero/zero2/module.py | 16 +++++++++-------
 3 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/tiny_deepspeed/core/zero/ddp/module.py b/tiny_deepspeed/core/zero/ddp/module.py
index 4b7de49..3bbdf9a 100644
--- a/tiny_deepspeed/core/zero/ddp/module.py
+++ b/tiny_deepspeed/core/zero/ddp/module.py
@@ -16,9 +16,12 @@
 def sync_grad(grad, async_op=True):
     # communication complexity: 2g
     if async_op:
-        return dist.all_reduce(grad, async_op=True)
+        work = dist.all_reduce(grad, async_op=True)
     else:
         dist.all_reduce(grad, async_op=False)
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 
 class Linear(linear.Linear):
diff --git a/tiny_deepspeed/core/zero/zero1/module.py b/tiny_deepspeed/core/zero/zero1/module.py
index 498badc..8c17008 100644
--- a/tiny_deepspeed/core/zero/zero1/module.py
+++ b/tiny_deepspeed/core/zero/zero1/module.py
@@ -16,10 +16,12 @@
 def sync_grad(grad, async_op=True, rank_id=None):
     # communication complexity: g
     if async_op:
-        return dist.reduce(grad, dst=rank_id, async_op=True)
+        work = dist.reduce(grad, dst=rank_id, async_op=True)
     else:
         dist.reduce(grad, dst=rank_id, async_op=False)
-        return None
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 
 class Linear(linear.Linear):
diff --git a/tiny_deepspeed/core/zero/zero2/module.py b/tiny_deepspeed/core/zero/zero2/module.py
index 9e62587..208d456 100644
--- a/tiny_deepspeed/core/zero/zero2/module.py
+++ b/tiny_deepspeed/core/zero/zero2/module.py
@@ -16,10 +16,12 @@
 def sync_grad(grad, async_op=True, rank_id=None):
     # communication complexity: g
     if async_op:
-        return dist.reduce(grad, dst=rank_id, async_op=True)
+        work = dist.reduce(grad, dst=rank_id, async_op=True)
     else:
         dist.reduce(grad, dst=rank_id, async_op=False)
-        return None
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 def desync_grad(grad, rank_id=None):
     if grad is not None and rank_id is not None:
@@ -27,11 +29,11 @@ def desync_grad(grad, rank_id=None):
             # print(dist.get_rank(), rank_id)
             grad.data = torch.randn(1, device=grad.device, dtype=grad.dtype)
             grad.data.to("cpu") # should actually be released but impossible in pytorch, maybe solved by plugin C++
-            torch.cuda.synchronize()
-            return None
-        else:
-            return grad
-    return grad
+            grad = None
+        torch.cuda.synchronize()
+        return grad
+    else:
+        return None
 
 
 class Linear(linear.Linear):
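
Usage sketch (reviewer note, not part of the patch): the change above makes sync_grad
return the async work handle instead of discarding it, so a caller can launch the
collective, overlap other work, and wait on the handle before the optimizer reads the
gradient. The sketch below re-declares a DDP-style sync_grad locally rather than
importing it from tiny_deepspeed, uses the gloo backend so it runs on CPU, and assumes
a two-rank torchrun launch; the file name sync_grad_sketch.py is only illustrative.

import torch
import torch.distributed as dist

def sync_grad(grad, async_op=True):
    # mirrors the patched DDP version: keep the work handle and return it
    if async_op:
        work = dist.all_reduce(grad, async_op=True)
    else:
        dist.all_reduce(grad, async_op=False)
        work = None
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return work

def main():
    dist.init_process_group("gloo")        # torchrun supplies MASTER_ADDR/PORT, RANK, WORLD_SIZE
    grad = torch.ones(4) * (dist.get_rank() + 1)
    work = sync_grad(grad, async_op=True)  # launch the all-reduce; other work could overlap here
    if work is not None:
        work.wait()                        # the reduced gradient is only valid after wait()
    print(f"rank {dist.get_rank()}: {grad.tolist()}")  # every rank sees the summed gradient
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Run with: torchrun --nproc_per_node=2 sync_grad_sketch.py
With two ranks the initial gradients [1,1,1,1] and [2,2,2,2] both print as [3,3,3,3]
after wait(), which is the guarantee the returned handle provides.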