Skip to content

Commit

Permalink
Merge branch 'main' into add_descales
Browse files (browse the repository at this point in the history)
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Sep 19, 2024
2 parents 39a4e1d + 0ee5ccd commit 5bcc355
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions transformer_engine/pytorch/attention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4888,24 +4888,23 @@ def forward(
)
else:
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous()
for x in (query_layer, key_layer, value_layer)
x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer)
]
else:
if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1).contiguous()
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
Expand Down Expand Up @@ -5140,19 +5139,15 @@ def convert_to_torch_float8(tensor, dtype):
.contiguous(),
)
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
# (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd":
# thd -> t(hd)
output = output.reshape(output.shape[0], -1)

return output
return output.contiguous()


def _combine_tensors(
Expand Down

0 comments on commit 5bcc355

Please sign in to comment.