[PyTorch] Miscellaneous fixes for FA3 FP8 attention #1174

Open · wants to merge 15 commits into base: main · Changes from 4 commits
39 changes: 32 additions & 7 deletions transformer_engine/pytorch/attention.py
@@ -4892,6 +4892,10 @@ def forward(
x.transpose(0, 1).contiguous()
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data)
for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
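Aside on the first hunk: transposing query_layer._data / key_layer._data / value_layer._data in place leaves the Float8Tensor wrappers out of sync with their payloads, so the added lines rebuild the wrappers with Float8Tensor.make_like. A minimal sketch of that rewrap pattern, assuming TE's Float8Tensor API (sbhd_to_bshd is an illustrative helper, not part of the diff):

    # Rewrap an FP8 payload after transposing it, reusing the source tensor's
    # scale/dtype metadata so scale_inv stays attached to the new layout.
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    def sbhd_to_bshd(t: Float8Tensor) -> Float8Tensor:
        data = t._data.transpose(0, 1).contiguous()  # (s, b, ...) -> (b, s, ...)
        return Float8Tensor.make_like(t, data=data)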
@@ -5027,24 +5031,40 @@ def forward(
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
fa_optional_forward_kwargs_fp8 = {}
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)

def convert_to_torch_float8(tensor, dtype):
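# Zero-copy alias: point a fresh tensor with a native torch float8 dtype
# (e.g. torch.float8_e4m3fn) at the Float8Tensor's raw uint8 storage via
# set_(), since the FA3 kernel takes torch float8 tensors, not TE wrappers.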
out = torch.Tensor().to(device=tensor.device, dtype=dtype)
out.set_(
tensor._data.untyped_storage(),
tensor._data.storage_offset(),
tensor._data.shape,
tensor._data.stride(),
)
return out

if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
-    x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
+    Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward)
+    for x in [query_layer, key_layer, value_layer]
)
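# Capture the per-tensor dequant scales from the TE wrappers before the
# convert_to_torch_float8 pass below; the resulting plain torch float8
# tensors no longer carry _scale_inv.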
fa_optional_forward_kwargs_fp8["descale_q"] = query_layer._scale_inv
fa_optional_forward_kwargs_fp8["descale_k"] = key_layer._scale_inv
fa_optional_forward_kwargs_fp8["descale_v"] = value_layer._scale_inv
query_layer, key_layer, value_layer = (
convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
@@ -5053,6 +5073,7 @@ def forward(
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
**fa_optional_forward_kwargs_fp8,
)
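For orientation, the assembled FA3 FP8 call ends up looking roughly like the sketch below (assuming the flash-attn 3 beta flash_attn_interface.flash_attn_func entry point; q8/k8/v8 and the *_scale_inv names are illustrative):

    # q8/k8/v8: torch-float8 views from convert_to_torch_float8;
    # *_scale_inv: float32 dequant scales captured from the TE wrappers.
    import flash_attn_interface  # flash-attn 3 (Hopper) beta

    out, _ = flash_attn_interface.flash_attn_func(
        q8, k8, v8,
        softmax_scale=softmax_scale,
        causal=True,
        deterministic=False,
        descale_q=q_scale_inv,
        descale_k=k_scale_inv,
        descale_v=v_scale_inv,
    )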
if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
@@ -5087,8 +5108,12 @@ def forward(
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
-    output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
-    output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
+    output = Float8Tensor.make_like(
+        output,
+        data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
+        .transpose(0, 1)
+        .contiguous(),
+    )
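# Shape walk-through with illustrative sizes (b=2, s=8 per CP rank, h=4, d=64):
#   output._data                      (2, 8, 4, 64)   bshd from the kernel
#   .reshape(b, s, -1)             -> (2, 8, 256)     fuse heads into (hd)
#   .transpose(0, 1).contiguous()  -> (8, 2, 256)     back to sb(hd)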
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)