diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e138f946e8..296b68af8a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5110,10 +5110,10 @@ def convert_to_torch_float8(tensor, dtype): if fp8 and fp8_meta["recipe"].fp8_mha: output = Float8Tensor.make_like( output, - data=output._data.reshape( - batch_size, max_seqlen_q // cp_size, -1 - ).transpose(0,1).contiguous() - ) + data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous(), + ) else: output = ( output.view(batch_size, max_seqlen_q // cp_size, -1)