From 223b148dddc94405cb565c67a521cee4ee217645 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Thu, 29 Aug 2024 22:34:31 -0700 Subject: [PATCH] hopper local attention --- hopper/flash.h | 1 + hopper/flash_api.cpp | 41 ++++++----- hopper/flash_attn_interface.py | 32 +++++++-- hopper/flash_bwd_kernel.h | 6 +- hopper/flash_bwd_launch_template.h | 28 +++++--- hopper/flash_fwd_kernel.h | 24 ++++--- hopper/flash_fwd_launch_template.h | 68 ++++++++++-------- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 68 ++++++++++++++---- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 90 +++++++++++++++++++----- 9 files changed, 252 insertions(+), 106 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 272bec7d1..24ca27f69 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params { bool is_bf16; bool is_e4m3; bool is_causal; + bool is_local; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 2ea08c247..a5c195883 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -131,7 +131,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. params.is_causal = window_size_left < 0 && window_size_right == 0; - + if ((window_size_left >= 0 || window_size_right >= 0) && !params.is_causal) { + params.is_local = true; + } + window_size_left = std::min(int(seqlen_k), window_size_left); + window_size_right = std::min(int(seqlen_k), window_size_right); if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } params.window_size_left = window_size_left; @@ -273,7 +277,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size c10::optional &descale_q_, // 1 c10::optional &descale_k_, // 1 c10::optional &descale_v_, // 1 - bool is_causal) { + bool is_causal, + int window_size_left, + int window_size_right) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -375,8 +381,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 0 : -1); + /*window_size_left=*/is_causal ? -1 : window_size_left, + /*window_size_right=*/is_causal ? 0 : window_size_right); auto tile_count_semaphore = is_causal ? 
torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -437,7 +443,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int max_seqlen_q, const int max_seqlen_k, const float softmax_scale, - bool is_causal) { + bool is_causal, + int window_size_left, + int window_size_right) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -468,10 +476,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int head_size_og = sizes[2]; const int num_heads_k = k.size(1); - int window_size_left = -1; - int window_size_right = -1; - if (is_causal) { window_size_right = 0; } - void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); const int total_q = q.sizes()[0]; @@ -480,9 +484,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (window_size_left >= max_seqlen_k) { window_size_left = -1; } - if (window_size_right >= max_seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, total_q, num_heads, head_size_og); const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); @@ -558,8 +559,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - window_size_left, - window_size_right, + is_causal ? -1 : window_size_left, + is_causal ? 0 : window_size_right, /*seqlenq_ngroups_swapped=*/false, /*unpadded_lse=*/true); params.total_q = total_q; @@ -620,6 +621,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const float softmax_scale, const bool is_causal, + const int window_size_left, + const int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -759,8 +762,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 0 : -1, + /*window_size_left=*/is_causal ? -1 : window_size_left, + /*window_size_right=*/is_causal ? 0 : window_size_right, deterministic); params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); @@ -811,6 +814,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x const int max_seqlen_k, // max sequence length to choose the kernel const float softmax_scale, const bool is_causal, + const int window_size_left, + const int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -973,8 +978,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 0 : -1, + /*window_size_left=*/is_causal ? -1 : window_size_left, + /*window_size_right=*/is_causal ? 
0 : window_size_right, deterministic); params.total_q = total_q; params.total_k = total_k; diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 4571eb394..2a610c998 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None): +def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, @@ -26,6 +26,8 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descal descale_k, descale_v, causal, + window_size[0], + window_size[1], ) return out, q, k, v, out_padded, softmax_lse, S_dmask @@ -42,6 +44,7 @@ def _flash_attn_backward( dv, softmax_scale, causal, + window_size, deterministic=False ): # dq, dk, dv are allocated by us so they should already be contiguous @@ -58,6 +61,8 @@ def _flash_attn_backward( dv, softmax_scale, causal, + window_size[0], + window_size[1], deterministic, ) return dq, dk, dv, softmax_d @@ -72,6 +77,7 @@ def _flash_attn_varlen_forward( max_seqlen_k, softmax_scale, causal, + window_size=(-1, -1), seqused_q=None, seqused_k=None, ): @@ -90,6 +96,8 @@ def _flash_attn_varlen_forward( max_seqlen_k, softmax_scale, causal, + window_size[0], + window_size[1], ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -112,6 +120,7 @@ def _flash_attn_varlen_backward( max_seqlen_k, softmax_scale, causal, + window_size, deterministic=False, seqused_q=None, seqused_k=None, @@ -143,6 +152,8 @@ def _flash_attn_varlen_backward( max_seqlen_k, softmax_scale, causal, + window_size[0], + window_size[1], deterministic, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): @@ -159,6 +170,7 @@ def forward( v, softmax_scale, causal, + window_size, deterministic=False, ): if softmax_scale is None: @@ -168,11 +180,13 @@ def forward( k, v, softmax_scale, - causal + causal, + window_size ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.deterministic = deterministic return out, softmax_lse @@ -192,6 +206,7 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, + ctx.window_size, ctx.deterministic, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension @@ -213,6 +228,7 @@ def forward( max_seqlen_k, softmax_scale, causal, + window_size, deterministic=False, seqused_q=None, seqused_k=None, @@ -229,6 +245,7 @@ def forward( max_seqlen_k, softmax_scale, causal=causal, + window_size=window_size, seqused_q=seqused_q, seqused_k=seqused_k, ) @@ -240,6 +257,7 @@ def forward( ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.deterministic = deterministic return out, softmax_lse @@ -263,6 +281,7 @@ def backward(ctx, dout, *args): ctx.max_seqlen_k, ctx.softmax_scale, ctx.causal, + ctx.window_size, ctx.deterministic, seqused_q, seqused_k, @@ -279,6 +298,7 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + window_size=(-1, -1), deterministic=False ): """dropout_p should be set to 0.0 during evaluation @@ -335,6 +355,7 @@ def flash_attn_func( v, softmax_scale, causal, + 
window_size, deterministic, ) @@ -349,6 +370,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale=None, causal=False, + window_size=(-1, -1), deterministic=False, seqused_q=None, seqused_k=None, @@ -382,9 +404,10 @@ def flash_attn_varlen_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of query and output tokens in each sequence. - seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of key and value tokens in each sequence. Return: out: (total, nheads, headdim). @@ -402,6 +425,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + window_size, deterministic, seqused_q, seqused_k, diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h index f4ba0ff47..61eee8bb2 100644 --- a/hopper/flash_bwd_kernel.h +++ b/hopper/flash_bwd_kernel.h @@ -225,7 +225,7 @@ class FlashAttnBwd { } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; @@ -251,7 +251,7 @@ class FlashAttnBwd { } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { continue; } } collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); @@ -281,7 +281,7 @@ class FlashAttnBwd { } if constexpr (Is_causal) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); continue; diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 1b0a852b6..a06de561c 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -20,7 +20,7 @@ using namespace cute; -template void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using TileShape_MK = cute::Shape, Int>; @@ -57,7 +57,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape<_1, Int<1>, _1>; static constexpr int Stages = 2; using CollectiveMainloop = flash::CollectiveMainloopBwd; using CollectiveEpilogue = flash::CollectiveEpilogueBwd; using Scheduler = flash::SingleTileSchedulerBwd; @@ -170,9 +170,11 @@ template void 
run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); @@ -182,9 +184,11 @@ template void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); @@ -194,9 +198,11 @@ template void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index 5d79ddeb9..9517c5e0c 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -24,9 +24,9 @@ namespace flash { using namespace cute; -template +template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k @@ -47,7 +47,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, // static constexpr int kBlockN = Ktraits::kBlockN; // constexpr int kHeadDim = Ktraits::kHeadDim; - using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveMainloop = CollectiveMainloopFwd; using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -121,9 +121,11 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { continue; } - int n_block_max = collective_mainloop.get_n_block_max( + const int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = collective_mainloop.get_n_block_min( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { + if ((Is_causal || Is_local || 
seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); continue; @@ -167,15 +169,17 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { continue; } - int n_block_max = collective_mainloop.get_n_block_max( + const int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = collective_mainloop.get_n_block_min( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); continue; } collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, - tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, + tOrO, softmax, n_block_max, n_block_min, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, seqlen_traits_q, seqlen_traits_k); // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, @@ -190,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k @@ -215,7 +219,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128; static constexpr bool Use_max_offset = true; - using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveMainloop = CollectiveMainloopFwd; using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 0e51769bb..60336a99e 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -18,7 +18,7 @@ #include "utils.h" -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using Element = typename Kernel_traits::Element; using OutputType = typename Kernel_traits::OutputType; @@ -26,12 +26,12 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = typename Kernel_traits::ClusterShape_MNK; // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); - using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveMainloop = flash::CollectiveMainloopFwd; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; using Scheduler 
= std::conditional_t< Seqlen_traits::kUseVarSeqLen, flash::SingleTileScheduler, - std::conditional_t >>; @@ -60,7 +60,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.scale_softmax_log2, params.descale_q_ptr, params.descale_k_ptr, - params.descale_v_ptr + params.descale_v_ptr, + params.window_size_left, + params.window_size_right }); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ @@ -85,7 +87,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if constexpr(cutlass::sizeof_bits_v == 8) kernel = (void *)flash::compute_attn_ws_fp8; else - kernel = (void *)flash::compute_attn_ws; + kernel = (void *)flash::compute_attn_ws; int smem_size = sizeof(typename Kernel_traits::SharedStorage); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); @@ -115,11 +117,13 @@ template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); } @@ -128,13 +132,15 @@ template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); }); @@ -144,13 +150,15 @@ template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); }); @@ -166,11 +174,11 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params ¶ms, 
cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { @@ -195,11 +203,11 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { @@ -224,11 +232,11 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index b54c2b5d1..24884d406 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -24,7 +24,7 @@ namespace flash { using namespace cute; template struct CollectiveMainloopBwd { @@ -36,6 +36,7 @@ struct CollectiveMainloopBwd { using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; static constexpr bool Varlen = Varlen_; static constexpr bool SdP_swapAB = true; static constexpr bool dKV_swapAB = dKV_swapAB_; @@ -281,6 +282,8 @@ struct CollectiveMainloopBwd { int const* cu_seqlens_k = nullptr; int const* seqused_k = nullptr; int const* seqused_v = nullptr; + int window_size_left; + int window_size_right; }; // Device side kernel params @@ -307,6 +310,8 @@ struct CollectiveMainloopBwd { int const* cu_seqlens_k = nullptr; int const* seqused_q = nullptr; int const* seqused_k = nullptr; + int window_size_left; + int window_size_right; }; static Params @@ -367,7 +372,7 @@ struct CollectiveMainloopBwd { args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, float(args.softmax_scale * M_LOG2E), args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, - args.seqused_k, args.seqused_v}; + args.seqused_k, args.seqused_v, args.window_size_left, args.window_size_right}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -412,15 +417,31 @@ struct CollectiveMainloopBwd { CUTLASS_DEVICE int get_m_block_min(Params const& params, 
int n_block, int bidb) { - if constexpr (Is_causal) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + if constexpr (Is_causal || Is_local) { int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); - return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM); + return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); } else { return 0; } } + CUTLASS_DEVICE + int get_m_block_max(Params const& params, int n_block, int bidb) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = get_seqlen_q(params, bidb); + int const seqlen_k = get_seqlen_k(params, bidb); + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_causal || Is_local) { + return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); + } else { + return m_block_max; + } + } + template CUTLASS_DEVICE void load(Params const& params, @@ -491,7 +512,7 @@ struct CollectiveMainloopBwd { } } - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; @@ -568,7 +589,7 @@ struct CollectiveMainloopBwd { Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; int const num_batch = params.num_batch; @@ -678,7 +699,7 @@ struct CollectiveMainloopBwd { int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; @@ -724,7 +745,7 @@ struct CollectiveMainloopBwd { // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. 
- if constexpr (Is_causal) { + if constexpr (Is_causal || Is_local) { static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) { @@ -740,12 +761,17 @@ struct CollectiveMainloopBwd { warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); - int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right; + int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset, - seqlen_k - n_block * kBlockN)) { + if (int(get<0>(taccScS(i))) >= + std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; + } else if constexpr (Is_local) { + if (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left)) { + tSrS(i) = -INFINITY; + } } } // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) @@ -800,10 +826,24 @@ struct CollectiveMainloopBwd { warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + if constexpr (!Is_local) { + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } + } else { + int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right; + int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if ((int(get<0>(taccScS(i))) >= + std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN) + ) || (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left))) { + tSrS(i) = -INFINITY; + } + } } + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 8094ad3ab..3ea4ee14e 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -79,7 +79,7 @@ struct SmemTransposeFp8_64x64 { } }; -template +template struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; @@ -158,6 +158,8 @@ struct CollectiveMainloopFwd { float const* descale_q_ptr; float const* descale_k_ptr; float const* descale_v_ptr; + int window_size_left; + int window_size_right; }; // Device side kernel params @@ -173,6 +175,8 @@ struct CollectiveMainloopFwd { float const* descale_q_ptr; float const* descale_k_ptr; float const* descale_v_ptr; + int window_size_left; + int window_size_right; }; @@ -203,7 +207,8 @@ struct CollectiveMainloopFwd { 
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2, - args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr}; + args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr, + args.window_size_left, args.window_size_right}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -225,13 +230,34 @@ struct CollectiveMainloopFwd { int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN)); + if constexpr (Is_causal || Is_local) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN)); } return n_block_max; } + CUTLASS_DEVICE + int get_n_block_min( + Params const& mainloop_params, int m_block, + const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k + ) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); + int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); + if constexpr (!Is_local) { + return 0; + } else { + return std::max( + 0, + (m_block * kBlockM + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN + ); + } + } + template CUTLASS_DEVICE void load(Params const& mainloop_params, @@ -288,7 +314,8 @@ struct CollectiveMainloopFwd { } } - int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -315,7 +342,7 @@ struct CollectiveMainloopFwd { if (lane_predicate) { // CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 2 - for (; n_block > 0; --n_block) { + for (; n_block > n_block_min; --n_block) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index())); @@ -645,6 +672,7 @@ struct CollectiveMainloopFwd { FrgTensorO& tOrO, Softmax& softmax, int n_block_count, + int n_block_min, int thread_idx, int work_idx, int m_block, @@ -706,23 +734,35 @@ struct CollectiveMainloopFwd { pipeline_k.consumer_release(smem_pipe_read_k); ++smem_pipe_read_k; - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + auto col_limit_right = [&](int row, int n_block) { + return std::min( + seqlen_k - n_block * kBlockN, + row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right + ); + }; + auto col_limit_left = [&](int row, int n_block) { + return std::max( + 0, + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - 
mainloop_params.window_size_left + ); }; { Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col + if constexpr (!Is_causal && !Is_local) { // Just masking based on col if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } else { // mask based on both row and col // using std::min is faster than doing col >= limit0 or col >= limit1 // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the // right hand side can be negative and might be converted to a very large unsigned integer. - if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, - col_limit_causal(int(get<0>(tScS(i))), n_block))) { + if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) { tSrS(i) = -INFINITY; + } else if constexpr (Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) { + tSrS(i) = -INFINITY; + } } } } @@ -733,10 +773,10 @@ struct CollectiveMainloopFwd { Tensor scores_scale = make_fragment_like(softmax.row_max); clear(scores_scale); - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) { + for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); warp_scheduler_barrier_sync(); @@ -751,8 +791,12 @@ struct CollectiveMainloopFwd { Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) { + if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) { tSrS(i) = -INFINITY; + } else if constexpr (Is_local) { + if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) { + tSrS(i) = -INFINITY; + } } } cute::copy(softmax.template max(tSrS), scores_scale); @@ -765,7 +809,7 @@ struct CollectiveMainloopFwd { } #pragma unroll 1 - for (; n_block > 0; --n_block) { + for (; n_block > n_block_min; --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); warp_scheduler_barrier_sync(); @@ -776,6 +820,20 @@ struct CollectiveMainloopFwd { warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read_k); // release K + + if constexpr(Is_local) { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if ( + int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block) + || int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block) + ) { + tSrS(i) = -INFINITY; + } + } + } // auto scores_scale = softmax.template max(tSrS); cute::copy(softmax.template max(tSrS), scores_scale); softmax.template online_softmax(tSrS);
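For reference, a minimal usage sketch of the new window_size argument that this patch adds to flash_attn_func / flash_attn_varlen_func in hopper/flash_attn_interface.py. Shapes, dtypes, and the import path are illustrative assumptions, not part of the patch; the (left, right) semantics and the -1 = unbounded convention follow the docstring added above. Note that as wired in mha_fwd / mha_bwd here, window_size is only forwarded when causal=False (is_causal overrides it with (-1, 0)), so a sliding causal window is requested with causal=False and window_size=(left, 0).

# Hypothetical usage; assumes the compiled hopper extension and this interface module are importable.
import torch
from flash_attn_interface import flash_attn_func

batch, seqlen, nheads, headdim = 2, 4096, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Sliding causal window: each query attends to itself and up to 511 earlier keys
# (window_size_right=0, plus the "+1" in the kernel's col_limit_right).
out, softmax_lse = flash_attn_func(q, k, v, causal=False, window_size=(511, 0))

# Symmetric local attention: up to 256 keys on either side of the diagonal.
out, softmax_lse = flash_attn_func(q, k, v, causal=False, window_size=(256, 256))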
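A pure-PyTorch sketch of the token-level mask implied by the kernel's col_limit_right / col_limit_left formulas, useful as a correctness reference against the patched kernels. The seqlen_k - seqlen_q shift (bottom-right alignment of the diagonal) and the treatment of -1 as "unbounded, clamped to seqlen_k" mirror set_params_fprop above; the function name is ours.

import torch

def local_attn_mask(seqlen_q, seqlen_k, window_size_left, window_size_right, device="cpu"):
    # True where a query may attend to a key; -1 means unbounded on that side,
    # matching the clamping to seqlen_k in set_params_fprop.
    if window_size_left < 0:
        window_size_left = seqlen_k
    if window_size_right < 0:
        window_size_right = seqlen_k
    i = torch.arange(seqlen_q, device=device).unsqueeze(1)   # query index
    j = torch.arange(seqlen_k, device=device).unsqueeze(0)   # key index
    shift = seqlen_k - seqlen_q                              # aligns the diagonal to the bottom-right
    return (j <= i + shift + window_size_right) & (j >= i + shift - window_size_left)

A reference check can add -inf to QK^T * softmax_scale wherever this mask is False, take softmax, multiply by V, and compare against the kernel output.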
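How the token-level window becomes block-level loop bounds: a sketch (integer math only, hypothetical helper names) mirroring get_n_block_min/get_n_block_max in the forward mainloop and get_m_block_min/get_m_block_max in the backward mainloop. Left and right swap roles in the backward pass because it fixes a key tile (n_block) and iterates over query tiles (m_block). When min >= max, the kernels above exit early and write zeros (forward O / backward dK, dV).

def cdiv(a, b):
    # Ceiling division for the non-negative sizes used here.
    return (a + b - 1) // b

def fwd_n_block_range(m_block, seqlen_q, seqlen_k, kBlockM, kBlockN,
                      window_size_left, window_size_right):
    # Key blocks visited for one query tile (mirrors get_n_block_min/get_n_block_max).
    n_block_max = cdiv(seqlen_k, kBlockN)
    n_block_max = min(n_block_max,
                      cdiv((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right, kBlockN))
    # Integer division; any rounding difference vs. C++ truncation is absorbed by max(0, ...).
    n_block_min = max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) // kBlockN)
    return n_block_min, n_block_max

def bwd_m_block_range(n_block, seqlen_q, seqlen_k, kBlockM, kBlockN,
                      window_size_left, window_size_right):
    # Query blocks visited for one key tile (mirrors get_m_block_min/get_m_block_max).
    m_block_max = cdiv(seqlen_q, kBlockM)
    m_block_max = min(m_block_max,
                      cdiv((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM))
    m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) // kBlockM)
    return m_block_min, m_block_max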