Commit

fix communication bugs
liangyuwang committed Aug 6, 2024
1 parent eb66368 commit db1495b
Showing 3 changed files with 17 additions and 10 deletions.
5 changes: 4 additions & 1 deletion tiny_deepspeed/core/zero/ddp/module.py
@@ -16,9 +16,12 @@
 
 def sync_grad(grad, async_op=True): # communication complexity: 2g
     if async_op:
-        return dist.all_reduce(grad, async_op=True)
+        work = dist.all_reduce(grad, async_op=True)
     else:
         dist.all_reduce(grad, async_op=False)
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 
 class Linear(linear.Linear):
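
sync_grad now returns the torch.distributed work handle when async_op=True (and None otherwise), with a CUDA synchronize on both paths. A minimal sketch of how a caller might consume those handles, assuming torch.distributed is initialized and the sync_grad above is in scope; the function name and the params iterable are illustrative, not taken from the repo:

# Hedged sketch (not part of this commit): consuming the handles returned by sync_grad.
# Assumes torch.distributed is initialized and sync_grad from ddp/module.py is in scope;
# `params` is any iterable of parameters whose .grad tensors are populated.
def reduce_all_grads_ddp(params):
    pending = [sync_grad(p.grad, async_op=True) for p in params if p.grad is not None]
    # Other backward-pass work could overlap with the in-flight all-reduces here.
    for work in pending:
        if work is not None:
            work.wait()  # block until each all_reduce has finished before the optimizer step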
6 changes: 4 additions & 2 deletions tiny_deepspeed/core/zero/zero1/module.py
@@ -16,10 +16,12 @@
 
 def sync_grad(grad, async_op=True, rank_id=None): # communication complexity: g
     if async_op:
-        return dist.reduce(grad, dst=rank_id, async_op=True)
+        work = dist.reduce(grad, dst=rank_id, async_op=True)
     else:
         dist.reduce(grad, dst=rank_id, async_op=False)
-        return None
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 
 class Linear(linear.Linear):
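
The ZeRO-1 variant reduces each gradient to a single owner rank via dist.reduce rather than all-reducing it, which is why the comment counts the communication as g instead of 2g. A hedged sketch of how per-parameter owners could be chosen and the handles drained; the round-robin owner_of helper is a hypothetical illustration, not the repo's actual partitioning scheme:

import torch.distributed as dist

def owner_of(param_index, world_size):
    # Hypothetical round-robin ownership, for illustration only.
    return param_index % world_size

def reduce_all_grads_zero1(params):
    world_size = dist.get_world_size()
    pending = []
    for i, p in enumerate(params):
        if p.grad is not None:
            pending.append(sync_grad(p.grad, async_op=True, rank_id=owner_of(i, world_size)))
    for work in pending:
        if work is not None:
            work.wait()  # after this, only the owner rank holds the summed gradient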
16 changes: 9 additions & 7 deletions tiny_deepspeed/core/zero/zero2/module.py
@@ -16,22 +16,24 @@
 
 def sync_grad(grad, async_op=True, rank_id=None): # communication complexity: g
     if async_op:
-        return dist.reduce(grad, dst=rank_id, async_op=True)
+        work = dist.reduce(grad, dst=rank_id, async_op=True)
     else:
         dist.reduce(grad, dst=rank_id, async_op=False)
-        return None
+        work = None
+    torch.cuda.synchronize()
+    return work
 
 def desync_grad(grad, rank_id=None):
     if grad is not None and rank_id is not None:
         if dist.get_rank() != rank_id:
             # print(dist.get_rank(), rank_id)
             grad.data = torch.randn(1, device=grad.device, dtype=grad.dtype)
             grad.data.to("cpu") # should actually be released but impossible in pytorch, maybe solved by plugin C++
-            torch.cuda.synchronize()
-            return None
-        else:
-            return grad
-    return grad
+            grad = None
+        torch.cuda.synchronize()
+        return grad
+    else:
+        return None
 
 
 class Linear(linear.Linear):
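
In the ZeRO-2 path, desync_grad now shrinks and drops the gradient on every rank except its owner (returning None there), so only the owner keeps the reduced gradient for its optimizer shard. A hedged sketch of how the two helpers might be chained per parameter; the owner argument and the param.grad reassignment are assumptions for illustration, not code from this commit:

def reduce_and_shard_grad(param, owner):
    # Hedged sketch (not part of this commit): reduce the gradient to its owner rank,
    # wait for the collective, then release it everywhere else.
    if param.grad is None:
        return
    work = sync_grad(param.grad, async_op=True, rank_id=owner)
    if work is not None:
        work.wait()  # make sure the reduce has landed before touching the gradient
    param.grad = desync_grad(param.grad, rank_id=owner)  # owner keeps grad; other ranks get None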
