Fix DeepSpeed sequence parallelism for batch sizes larger than 1 (#5823)

Modified the all-to-all functions (`single_all_to_all` and `post_all2all`) to handle both batch-first and sequence-first layouts.
Verified the results against a `TP`-only baseline:

![image](https://github.com/user-attachments/assets/9bdd8942-3565-418f-b7be-614293b2f2f6)
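
As a quick illustration of what the new `batch_dim_idx` argument controls, here is a minimal usage sketch of the updated `DistributedAttention.forward` signature. The constructor arguments (`local_attn`, `sp_group`) and the toy attention callable are assumptions for illustration only and are not part of this change; a real setup would pass an initialized sequence-parallel process group and a FlashAttention-style kernel.

```python
# Illustrative sketch only: assumes torch.distributed is initialized and that
# `sp_group` is the sequence-parallel process group; `local_attn` is a toy
# stand-in for the real local attention kernel.
import torch
from deepspeed.sequence.layer import DistributedAttention

def local_attn(q, k, v, *args, **kwargs):
    # q, k, v: [batch, seq, heads, head_dim]
    scores = torch.einsum("bsnh,btnh->bnst", q, k) / q.shape[-1]**0.5
    return torch.einsum("bnst,btnh->bsnh", scores.softmax(dim=-1), v)

dist_attn = DistributedAttention(local_attn, sp_group)  # assumed constructor arguments

# Batch-first inputs: [batch, local_seq, num_heads, head_dim]; batch size > 1 is
# exactly the case this commit fixes.
q = k = v = torch.randn(4, 512, 16, 64)
out = dist_attn(q, k, v, batch_dim_idx=0)  # 0 -> [b, s, n, h], 1 -> [s, b, n, h]
```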

---------

Co-authored-by: Jinghan Yao <[email protected]>
Co-authored-by: Sam Ade Jacobs <[email protected]>
Co-authored-by: Jinghan Yao <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
5 people authored Aug 8, 2024
1 parent ade7149 commit ffe0af2
Showing 1 changed file with 69 additions and 38 deletions.
deepspeed/sequence/layer.py: 107 changes (69 additions & 38 deletions)
@@ -12,48 +12,76 @@
from deepspeed.accelerator import get_accelerator


def post_all2all(transpose, res_shape):
def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):

def post_func(input):
if transpose:
input = input.transpose(0, 2).contiguous()
input = input.reshape(res_shape)
return input
if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.permute(1, 0, 2, 3, 4).contiguous()
output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
head_dim).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
return output

return post_func


def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None):
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
head_dim]).contiguous()
input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
head_dim]).contiguous()
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

if scatter_idx < 2:
input_t = input.reshape(
[seq_world_size, inp_shape[scatter_idx]] + \
inp_shape[scatter_idx + 1:]
).contiguous()
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
head_dim)
else:
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
input_t = input.reshape(
[-1, seq_world_size, inp_shape[scatter_idx]] + \
inp_shape[scatter_idx + 1:]
).transpose(0, 1).contiguous()
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
head_dim)

output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

res_shape=( inp_shape[: gather_idx] + \
[inp_shape[gather_idx] * seq_world_size,] + \
inp_shape[gather_idx + 1:])
transpose = True if scatter_idx < 2 else False
post_all2all_fun = post_all2all(transpose, res_shape)

if async_op:
if type in ('dq', 'dk'):
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
return output.view(res_shape)
return output

res = post_all2all_fun(output)
return res
@@ -67,6 +95,7 @@ def forward(ctx: Any,
input: Tensor,
scatter_idx: int,
gather_idx: int,
batch_dim_idx: int,
stream=None,
handle=None,
type=None,
@@ -77,39 +106,40 @@ def forward(ctx: Any,
ctx.stream = stream
ctx.handle = handle
ctx.type = type
ctx.batch_dim_idx = batch_dim_idx
if ctx.handle is None:
res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)

else:
# overlap communication path
if not is_fwd and type == 'o':
assert ctx.stream != None
res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
del ctx.stream.activation_buffer_list
# The computation of d o_weight can overlap with the communication of d o_input

elif not is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv
type = 'd' + type
res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type)
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type)

elif is_fwd and type in ('q', 'k'):
# Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v
type = 'fwd_' + type
res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type)
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type)

else:
res = single_all_to_all(input, scatter_idx, gather_idx, group, False)
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)

return res

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:

return (None,
_SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle,
ctx.type, False), None, None, None, None, None, None)
_SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)


class DistributedAttention(torch.nn.Module):
@@ -148,13 +178,14 @@ def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.dafult_stream.wait_event(layer.done_event)

def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
batch_dim_idx (int): indicating which dim is batch
args: other args
Returns:
@@ -179,15 +210,15 @@ def pre_hook_fun(grad):
return pre_hook_fun

self.layer_sync(query)
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None,
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'q')
self.layer_sync(key)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles,
'k')
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'k')
if self.sp_overlap_comm:
self.dafult_stream.wait_stream(self.sp_stream)

value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None,
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'v')

if self.sp_overlap_comm:
@@ -205,8 +236,8 @@ def pre_hook_fun(grad):

context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream,
self.overlap_handles, 'o')
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx,
self.sp_stream, self.overlap_handles, 'o')

#out e.g., [s/p::h]
return output
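
For reference, a small single-process sketch of the reshape/permute bookkeeping that the patched `single_all_to_all` and `post_all2all` perform for the batch-first, scatter-sequence case (`batch_dim_idx == 0`, `scatter_idx < 2`). It only checks shapes and skips the actual `dist.all_to_all_single` exchange, so the tensor contents are not meaningful; the dimension names mirror the patch.

```python
# Shape-only sketch (no real all-to-all); mirrors the batch-first, scatter_idx < 2 path.
import torch

bs, global_seq_len, num_local_head, head_dim = 2, 8, 4, 16
seq_world_size = 4  # pretend sequence-parallel group size

x = torch.randn(bs, global_seq_len, num_local_head, head_dim)  # [b, s, n/P, h]

# Scatter side: split the sequence dim into seq_world_size chunks and move the
# chunk dim to the front so all_to_all_single can exchange contiguous slices.
x_t = x.reshape(bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
                head_dim).permute(1, 0, 2, 3, 4).contiguous()
assert x_t.shape == (seq_world_size, bs, global_seq_len // seq_world_size, num_local_head, head_dim)

# Gather side (post_all2all): rebuild a batch-first tensor that keeps the local
# sequence shard and concatenates the heads received from every rank.
out = x_t.permute(1, 2, 0, 3, 4).contiguous().reshape(bs, global_seq_len // seq_world_size,
                                                      seq_world_size * num_local_head, head_dim)
assert out.shape == (bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim)
```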
