
Commit

update all_to_all (#87)
Lay2000 authored Oct 14, 2024
1 parent 81607c8 commit 582a9f9
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions yunchang/comm/all_to_all.py
@@ -50,8 +50,11 @@ def all_to_all_4D(
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
-        dist.all_to_all_single(output, input_t, group=group)
-
+        if seq_world_size > 1:
+            dist.all_to_all_single(output, input_t, group=group)
+            torch.cuda.synchronize()
+        else:
+            output = input_t
         # if scattering the seq-dim, transpose the heads back to the original dimension
         output = output.reshape(seqlen, bs, shard_hc, hs)

@@ -80,7 +83,11 @@ def all_to_all_4D(
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
-        dist.all_to_all_single(output, input_t, group=group)
+        if seq_world_size > 1:
+            dist.all_to_all_single(output, input_t, group=group)
+            torch.cuda.synchronize()
+        else:
+            output = input_t
 
         # if scattering the seq-dim, transpose the heads back to the original dimension
         output = output.reshape(hc, shard_seqlen, bs, hs)
@@ -162,7 +169,11 @@ def all_to_all_5D(
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head
-        dist.all_to_all_single(output, input_t, group=group)
+        if seq_world_size > 1:
+            dist.all_to_all_single(output, input_t, group=group)
+            torch.cuda.synchronize()
+        else:
+            output = input_t
 
         # if scattering the seq-dim, transpose the heads back to the original dimension
         output = output.reshape(seqlen, 3, bs, shard_hc, hs)
@@ -191,7 +202,11 @@ def all_to_all_5D(
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
-        dist.all_to_all_single(output, input_t, group=group)
+        if seq_world_size > 1:
+            dist.all_to_all_single(output, input_t, group=group)
+            torch.cuda.synchronize()
+        else:
+            output = input_t
 
         # if scattering the seq-dim, transpose the heads back to the original dimension
         output = output.reshape(hc, shard_seqlen, 3, bs, hs)
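Each hunk applies the same guard: dist.all_to_all_single is only issued when the sequence-parallel group spans more than one rank; on a single rank the input tensor is reused directly. A minimal, self-contained sketch of that pattern (not the library code itself; the function name guarded_all_to_all is illustrative):

import torch
import torch.distributed as dist


def guarded_all_to_all(input_t: torch.Tensor, group=None) -> torch.Tensor:
    """Run the all-to-all exchange only when the group has more than one rank."""
    seq_world_size = dist.get_world_size(group)
    if seq_world_size > 1:
        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        dist.all_to_all_single(output, input_t, group=group)
        # The collective is asynchronous on CUDA tensors; wait for it before
        # the caller reshapes/transposes the result.
        torch.cuda.synchronize()
        return output
    # Single-rank group: the exchange would hand the tensor back unchanged.
    return input_t

Called under an initialized process group (e.g. dist.init_process_group("nccl")), this returns the exchanged tensor on multi-rank groups and the untouched input on a single rank, which is what lets the reshape/transpose code after each call stay the same in both cases.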
