diff --git a/xformers/ops/fmha/flash3.py b/xformers/ops/fmha/flash3.py
index 900b9fdd5..23e57e769 100644
--- a/xformers/ops/fmha/flash3.py
+++ b/xformers/ops/fmha/flash3.py
@@ -7,6 +7,12 @@
 from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple
 
 import torch
+from torch.utils.flop_counter import (
+    _flash_attention_backward_flop,
+    _unpack_flash_attention_nested_shapes,
+    bmm_flop,
+    register_flop_formula,
+)
 
 from ..common import get_operator, register_operator
 from .attn_bias import (
@@ -51,6 +57,32 @@
         _C_flashattention3 = None
 
 
+# Copied from PyTorch, modified to support MQA/GQA.
+# No need to take care of this for the bwd because we don't "unexpand" the keys
+# and values (in the fwd we expand to help with the seqlen/headdim swap trick).
+def sdpa_flop_count(query_shape, key_shape, value_shape):
+    """
+    Count flops for self-attention.
+
+    NB: We can assume that value_shape == key_shape
+    """
+    b, h_q, s_q, d_q = query_shape
+    _b2, h_kv, s_k, _d2 = key_shape
+    _b3, _h2, _s3, d_v = value_shape
+    assert b == _b2 == _b3
+    assert h_kv == _h2
+    assert d_q == _d2
+    assert s_k == _s3
+    assert d_q == _d2
+    assert h_q % h_kv == 0
+    total_flops = 0
+    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
+    total_flops += bmm_flop((b * h_q, s_q, d_q), (b * h_q, d_q, s_k))
+    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
+    total_flops += bmm_flop((b * h_q, s_q, s_k), (b * h_q, s_k, d_v))
+    return total_flops
+
+
 if _C_flashattention3 is not None:
     # returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
     @torch.library.custom_op(
@@ -68,7 +100,7 @@ def mha_fwd(
         p: float,
         softmax_scale: float,
         is_causal: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         win_left = win_right = -1
         if cu_seqlens_q is None:
             use_gqa_packing = False
@@ -130,7 +162,7 @@ def mha_fwd_fake(
         p: float,
         softmax_scale: float,
         is_causal: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         query_shape = query.shape
         out = query.new_empty(query_shape)
         # Query is (B, M, H, K) or (total_M, H, K)
@@ -143,6 +175,43 @@ def mha_fwd_fake(
         lse = query.new_empty(lse_shape, dtype=torch.float32)
         return out, lse
 
+    @register_flop_formula(torch.ops.xformers_flash3.flash_fwd, get_raw=True)
+    def mha_fwd_flops(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        seqused_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        p: float,
+        softmax_scale: float,
+        is_causal: bool,
+        # The FLOPs counter might pass more args (out_val, out_shape, ...)
+        *args,
+        **kwargs,
+    ):
+        assert 3 <= query.ndim <= 4
+        assert 3 <= key.ndim <= 4
+        assert 3 <= value.ndim <= 4
+        sizes = _unpack_flash_attention_nested_shapes(
+            query=query.transpose(-2, -3) if query.ndim == 4 else query,
+            key=key.transpose(-2, -3) if key.ndim == 4 else key,
+            value=value.transpose(-2, -3) if value.ndim == 4 else value,
+            cum_seq_q=cu_seqlens_q,
+            cum_seq_k=cu_seqlens_k,
+            max_q=max_seqlen_q,
+            max_k=max_seqlen_k,
+        )
+        res = sum(
+            sdpa_flop_count(query_shape, key_shape, value_shape)
+            for query_shape, key_shape, value_shape, _ in sizes
+        )
+        if is_causal:
+            res /= 2
+        return res
+
     def _create_dq_dk_dv(
         grads_share_storage: bool, query, key, value
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -244,6 +313,45 @@ def mha_bwd_fake(
         dv = torch.empty_like(value)
         return dq, dk, dv
 
+    @register_flop_formula(torch.ops.xformers_flash3.flash_bwd, get_raw=True)
+    def mha_bwd_flops(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+        # The FLOPs counter might pass more args (out_val, out_shape, ...)
+        *args,
+        **kwargs,
+    ):
+        assert 3 <= dout.ndim <= 4
+        assert 3 <= query.ndim <= 4
+        assert 3 <= key.ndim <= 4
+        assert 3 <= value.ndim <= 4
+        res = _flash_attention_backward_flop(
+            dout.transpose(-2, -3) if dout.ndim == 4 else dout,
+            query.transpose(-2, -3) if query.ndim == 4 else query,
+            key.transpose(-2, -3) if key.ndim == 4 else key,
+            value.transpose(-2, -3) if value.ndim == 4 else value,
+            out,
+            softmax_lse,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+        )
+        if is_causal:
+            res /= 2
+        return res
+
 
 @register_operator
 class FwOp(AttentionFwOpBase):
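
A minimal usage sketch (illustrative, not part of the patch) of what these registered formulas enable: counting FLOPs for the flash3 ops through PyTorch's FlopCounterMode. It assumes an xformers build where torch.ops.xformers_flash3.flash_fwd/flash_bwd are actually registered (i.e. Flash-Attention 3 is available); the shapes and the explicit op= selection below are only examples.

    # Sketch: exercise the flash3 FLOP formulas via FlopCounterMode.
    import torch
    from torch.utils.flop_counter import FlopCounterMode

    import xformers.ops as xops
    from xformers.ops.fmha import flash3

    q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    with FlopCounterMode(display=False) as flop_counter:
        out = xops.memory_efficient_attention(q, k, v, op=(flash3.FwOp, flash3.BwOp))
        out.sum().backward()

    # Total covers both flash_fwd and flash_bwd thanks to mha_fwd_flops / mha_bwd_flops.
    print(flop_counter.get_total_flops())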