 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 from torch import Tensor, Generator
 from typing import Optional, Tuple
 from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR
 from ..utility import dtypes
 import torch
 
+
 @compile_ops("module_mha_fwd", fc_name="mha_fwd")
 def mha_fwd(
     q: Tensor,
@@ -48,7 +49,7 @@ def mha_varlen_fwd(
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-): ...
+) -> list[Tensor]: ...
 
 
 @compile_ops("module_mha_bwd", fc_name="mha_bwd")
@@ -419,7 +420,9 @@ def pssk():
         # bwd_hd64_bf16_causal_a32_rtz_pssk
         # bwd_hd64_fp16_a32_pssk
         # bwd_hd64_fp16_causal_a32_pssk
-        ret = is_v3_atomic_fp32 == True  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64
         ret &= nmask or (
             mask and seqlen_q == seqlen_k
@@ -474,7 +477,9 @@ def psskddv():
         # bwd_hd192_bf16_causal_a32_rtz_psskddv
         ret = is_v3_atomic_fp32 == True
         ret &= hdim_q > 64 and hdim_q <= 192
-        ret &= nmask or (mask and seqlen_q == seqlen_k)  # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
+        ret &= nmask or (
+            mask and seqlen_q == seqlen_k
+        )  # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
 
         return ret
 
@@ -759,6 +764,7 @@ def _flash_attn_varlen_forward(
     return_lse: bool = False,
     return_softmax: bool = False,
     block_table: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # causal=true is the same as causal=false in this case
@@ -878,7 +884,7 @@ def _flash_attn_varlen_forward(
         window_size_right,
         return_lse,
         return_softmax,
-        None,
+        out,
         block_table,
         bias,
         alibi_slopes,
@@ -963,7 +969,9 @@ def _flash_attn_varlen_backward(
     ]
 
     (_, nhead_q, hdim_q) = q.shape
-    (_, nhead_k, hdim_v) = v.shape
+
+    nhead_k = v.shape[-2]
+    hdim_v = v.shape[-1]
 
     # mask
     window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
@@ -994,12 +1002,14 @@ def pssk():
         # bwd_hd128_bf16_causal_a32_rtz_pssk_group
         # bwd_hd128_fp16_a32_pssk_group
         # bwd_hd128_fp16_causal_a32_pssk_group
-        ret = is_v3_atomic_fp32 == True  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64 or hdim_q == 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
 
         return ret
-
+
     def psskddv():
         # bwd_hd128_bf16_a32_rtne_psskddv_group
         # bwd_hd128_bf16_a32_rtna_psskddv_group
@@ -1009,9 +1019,11 @@ def psskddv():
         # bwd_hd128_bf16_causal_a32_rtz_psskddv_group
         # bwd_hd128_fp16_a32_psskddv_group
         # bwd_hd128_fp16_causal_a32_psskddv_group
-        ret = is_v3_atomic_fp32 == True  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q > 64 and hdim_q < 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
 
         return ret
 
@@ -1027,7 +1039,7 @@ def can_impl_fmha_v3_bwd():
         ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
         ret &= mask or nmask
         ret &= pssk() or psskddv()
-        ret &= 'gfx942' in torch.cuda.get_device_properties("cuda").gcnArchName
+        ret &= "gfx942" in torch.cuda.get_device_properties("cuda").gcnArchName
 
         return ret
 
@@ -1122,15 +1134,16 @@ def forward(
         return_lse,
         return_softmax,
         block_table,
+        out,
         is_grad_enabled,
         is_v3_atomic_fp32: Optional[bool] = True,
         how_v3_bf16_cvt: Optional[int] = 1,
     ):
         is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        head_size_q_og = q.size(2)
-        head_size_v_og = v.size(2)
+        head_size_q_og = q.size(-1)
+        head_size_v_og = v.size(-1)
         if head_size_q_og % 8 != 0:
             q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
@@ -1154,6 +1167,7 @@ def forward(
             return_lse=return_lse,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         if is_grad:
             ctx.save_for_backward(
@@ -1243,6 +1257,7 @@ def backward(ctx, dout, *args):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -1264,6 +1279,7 @@ def flash_attn_varlen_func(
     return_lse=False,
     return_attn_probs=False,
     block_table=None,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1338,5 +1354,6 @@ def flash_attn_varlen_func(
         return_lse,
         return_attn_probs,
         block_table,
+        out,
         torch.is_grad_enabled(),
     )