Provide FLOPs formula for FlashAttention3
ghstack-source-id: 5eb3152cb965ccb4a6227e50f026539ca5bccd76
Pull Request resolved: fairinternal/xformers#1281

__original_commit__ = fairinternal/xformers@02f4d4b
lw authored and xFormers Bot committed Jan 16, 2025
1 parent 08cc74d commit 536363e
Showing 1 changed file with 110 additions and 2 deletions.
xformers/ops/fmha/flash3.py
@@ -7,6 +7,12 @@
from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple

import torch
from torch.utils.flop_counter import (
    _flash_attention_backward_flop,
    _unpack_flash_attention_nested_shapes,
    bmm_flop,
    register_flop_formula,
)
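# Private helpers from torch.utils.flop_counter used below: bmm_flop counts
# batched-matmul FLOPs, _unpack_flash_attention_nested_shapes handles the
# (possibly varlen) cu_seqlens layouts, _flash_attention_backward_flop counts
# the backward pass, and register_flop_formula hooks our custom ops into
# FlopCounterMode.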

from ..common import get_operator, register_operator
from .attn_bias import (
@@ -51,6 +57,32 @@
_C_flashattention3 = None


# Copied from PyTorch, modified to support MQA/GQA.
# No need to take care of this for the bwd because we don't "unexpand" the keys
# and values (in the fwd we expand to help with the seqlen/headdim swap trick).
def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.
    NB: We can assume that value_shape == key_shape
    """
    b, h_q, s_q, d_q = query_shape
    _b2, h_kv, s_k, _d2 = key_shape
    _b3, _h2, _s3, d_v = value_shape
    assert b == _b2 == _b3
    assert h_kv == _h2
    assert d_q == _d2
    assert s_k == _s3
    assert h_q % h_kv == 0
    total_flops = 0
    # q: [b, h_q, s_q, d_q] @ k (expanded): [b, h_q, d_q, s_k] -> scores: [b, h_q, s_q, s_k]
    total_flops += bmm_flop((b * h_q, s_q, d_q), (b * h_q, d_q, s_k))
    # scores: [b, h_q, s_q, s_k] @ v (expanded): [b, h_q, s_k, d_v] -> out: [b, h_q, s_q, d_v]
    total_flops += bmm_flop((b * h_q, s_q, s_k), (b * h_q, s_k, d_v))
    return total_flops
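
# Worked example (editor's sketch, not part of the upstream diff): for a GQA
# setup with query_shape = (2, 8, 1024, 128) and key/value shapes
# (2, 2, 1024, 128), i.e. h_q=8 query heads sharing h_kv=2 key/value heads,
# both matmuls are counted at the query head count, so
#     sdpa_flop_count((2, 8, 1024, 128), (2, 2, 1024, 128), (2, 2, 1024, 128))
# returns 2*b*h_q*s_q*d_q*s_k + 2*b*h_q*s_q*s_k*d_v
#       = 4 * 2 * 8 * 1024 * 1024 * 128 = 8_589_934_592.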


if _C_flashattention3 is not None:
    # returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
    @torch.library.custom_op(
@@ -68,7 +100,7 @@ def mha_fwd(
        p: float,
        softmax_scale: float,
        is_causal: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
        win_left = win_right = -1
        if cu_seqlens_q is None:
            use_gqa_packing = False
@@ -130,7 +162,7 @@ def mha_fwd_fake(
        p: float,
        softmax_scale: float,
        is_causal: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query_shape = query.shape
        out = query.new_empty(query_shape)
        # Query is (B, M, H, K) or (total_M, H, K)
@@ -143,6 +175,43 @@ def mha_fwd_fake(
        lse = query.new_empty(lse_shape, dtype=torch.float32)
        return out, lse

    @register_flop_formula(torch.ops.xformers_flash3.flash_fwd, get_raw=True)
    def mha_fwd_flops(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        seqused_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        p: float,
        softmax_scale: float,
        is_causal: bool,
        # The FLOPs counter might pass more args (out_val, out_shape, ...)
        *args,
        **kwargs,
    ):
        assert 3 <= query.ndim <= 4
        assert 3 <= key.ndim <= 4
        assert 3 <= value.ndim <= 4
        sizes = _unpack_flash_attention_nested_shapes(
            query=query.transpose(-2, -3) if query.ndim == 4 else query,
            key=key.transpose(-2, -3) if key.ndim == 4 else key,
            value=value.transpose(-2, -3) if value.ndim == 4 else value,
            cum_seq_q=cu_seqlens_q,
            cum_seq_k=cu_seqlens_k,
            max_q=max_seqlen_q,
            max_k=max_seqlen_k,
        )
        res = sum(
            sdpa_flop_count(query_shape, key_shape, value_shape)
            for query_shape, key_shape, value_shape, _ in sizes
        )
        if is_causal:
            res /= 2
        return res
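
    # Usage sketch (editor's note, not part of the upstream diff): once this
    # formula is registered, torch.utils.flop_counter.FlopCounterMode picks it
    # up automatically whenever the flash3 forward op runs, e.g. assuming a
    # build/GPU where FlashAttention3 is available:
    #
    #     from torch.utils.flop_counter import FlopCounterMode
    #     from xformers.ops import fmha
    #
    #     q = k = v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
    #     with FlopCounterMode(display=False) as counter:
    #         fmha.memory_efficient_attention(q, k, v, op=(FwOp, BwOp))
    #     counter.get_total_flops()  # 4 * 2 * 8 * 1024 * 1024 * 128 = 8_589_934_592
    #
    # FwOp/BwOp here are the operator classes defined further down in this file.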

    def _create_dq_dk_dv(
        grads_share_storage: bool, query, key, value
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -244,6 +313,45 @@ def mha_bwd_fake(
        dv = torch.empty_like(value)
        return dq, dk, dv

    @register_flop_formula(torch.ops.xformers_flash3.flash_bwd, get_raw=True)
    def mha_bwd_flops(
        grads_share_storage: bool,
        dout: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        softmax_lse: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        is_causal: bool,
        # The FLOPs counter might pass more args (out_val, out_shape, ...)
        *args,
        **kwargs,
    ):
        assert 3 <= dout.ndim <= 4
        assert 3 <= query.ndim <= 4
        assert 3 <= key.ndim <= 4
        assert 3 <= value.ndim <= 4
        res = _flash_attention_backward_flop(
            dout.transpose(-2, -3) if dout.ndim == 4 else dout,
            query.transpose(-2, -3) if query.ndim == 4 else query,
            key.transpose(-2, -3) if key.ndim == 4 else key,
            value.transpose(-2, -3) if value.ndim == 4 else value,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
        )
        if is_causal:
            res /= 2
        return res
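
    # Editor's note (not part of the upstream diff): the backward count simply
    # delegates to PyTorch's _flash_attention_backward_flop, so a full training
    # step is covered too, e.g. continuing the sketch above with inputs that
    # require grad:
    #
    #     with FlopCounterMode(display=False) as counter:
    #         fmha.memory_efficient_attention(q, k, v, op=(FwOp, BwOp)).sum().backward()
    #
    # As in the forward formula, `res /= 2` under is_causal reflects that a
    # causal mask only touches roughly half of the s_q x s_k score matrix.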


@register_operator
class FwOp(AttentionFwOpBase):
