Skip to content

Commit

Permalink
[PyTorch] Relax the contiguous check for flash attention (#1176)
Browse files Browse the repository at this point in the history
* relax contiguous check for flash attention

Signed-off-by: Xin Yao <[email protected]>

* force contiguous for cp

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Sep 19, 2024
1 parent c0caadb commit 0ee5ccd
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4881,20 +4881,19 @@ def forward(
)
else:
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous()
for x in (query_layer, key_layer, value_layer)
x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer)
]
else:
if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1).contiguous()
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
Expand Down Expand Up @@ -5092,19 +5091,15 @@ def forward(
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
# (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd":
# thd -> t(hd)
output = output.reshape(output.shape[0], -1)

return output
return output.contiguous()


def _combine_tensors(
Expand Down

0 comments on commit 0ee5ccd

Please sign in to comment.