Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files: browse the repository at this point in the history
  • Loading branch information
tjruwase authored Dec 11, 2024
2 parents ddcd2ab + ecb4bf3 commit 2e19cfa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/sequence/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
# we only need num_heads once
num_heads = input.shape[2]

- if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
+ if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and not scatter_idx < 2):
# Assuming here that the number of heads for q is consistent with kv
# If not, additional logic is required for cases like GQA
if get_num_kv_heads() is None:
Expand Down

0 comments on commit 2e19cfa

Please sign in to comment.