Skip to content

Commit

Permalink
early all-reduce total_norm in non-PP grad norm clipping
Browse files Browse the repository at this point in the history
ghstack-source-id: cf1729cce656e17f1e3db5a8eb33dcd2d284a3d0
Pull Request resolved: #769
  • Loading branch information
tianyu-l committed Jan 2, 2025
1 parent 3f20451 commit d0434a3
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions torchtitan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,16 +384,18 @@ def clip_grad_norm_(
grads, norm_type, error_if_nonfinite, foreach
)

if pp_mesh is not None:
if isinstance(total_norm, DTensor):
# will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor

# if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
# we can simply reduce the DTensor to get the total norm in this tensor's process group
# and then convert it to a local tensor
total_norm = total_norm.full_tensor()
# If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
# We can simply reduce the DTensor to get the total norm in this tensor's process group
# and then convert it to a local tensor.
# NOTE: It has two purposes:
# 1. to make sure the total norm is computed correctly when PP is used (see below)
# 2. to return a reduced total_norm tensor whose .item() would return the correct value
if isinstance(total_norm, DTensor):
# Will reach here if any non-PP parallelism is used.
# If only using PP, total_norm will be a local tensor.
total_norm = total_norm.full_tensor()

# TODO: cleanup maybe using DTensor
if pp_mesh is not None:
if math.isinf(norm_type):
dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
else:
Expand Down

0 comments on commit d0434a3

Please sign in to comment.