From 871994ee0fe77c1124ab14ed69c31fcc99e85211 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Fri, 30 Aug 2024 17:20:18 -0700 Subject: [PATCH] fixes --- .eggs/README.txt | 6 ++ flash_attn/bert_padding.py | 7 +-- flash_attn/flash_blocksparse_attention.py | 2 +- flash_attn/models/bert.py | 2 +- hopper/flash_api.cpp | 45 +++++++++------ hopper/flash_attn_interface.py | 4 +- hopper/flash_bwd_kernel.h | 26 +++++---- hopper/flash_bwd_launch_template.h | 3 +- hopper/flash_fwd_launch_template.h | 6 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 27 +++++---- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 21 +++---- hopper/setup.py | 2 +- hopper/test_flash_attn.py | 70 +++++++++++++++-------- hopper/tile_scheduler.hpp | 2 +- tests/test_flash_attn.py | 6 +- tests/test_rotary.py | 2 +- 16 files changed, 134 insertions(+), 97 deletions(-) create mode 100644 .eggs/README.txt diff --git a/.eggs/README.txt b/.eggs/README.txt new file mode 100644 index 000000000..5d0166882 --- /dev/null +++ b/.eggs/README.txt @@ -0,0 +1,6 @@ +This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins. + +This directory caches those eggs to prevent repeated downloads. + +However, it is safe to delete this directory. + diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 083a79441..71ab43d8f 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -119,16 +119,13 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None): # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, # so we write custom forward and backward to make it a bit faster. - res = ( + return ( index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices), indices, cu_seqlens, max_seqlen_in_batch, + used_seqlens_in_batch, ) - if unused_mask is not None: - return res + (used_seqlens_in_batch, ) - else: - return res def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): diff --git a/flash_attn/flash_blocksparse_attention.py b/flash_attn/flash_blocksparse_attention.py index 03798d16f..4c9302910 100644 --- a/flash_attn/flash_blocksparse_attention.py +++ b/flash_attn/flash_blocksparse_attention.py @@ -99,7 +99,7 @@ def forward( key_padding_mask_bool = key_padding_mask.bool_matrix nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) + x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool) x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_blocksparse_attn_func( x_unpad, diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index 33d693520..6a78b1ea9 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -172,7 +172,7 @@ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): hidden_states = hidden_states[subset_mask] else: batch, seqlen = hidden_states.shape[:2] - hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( + hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input( hidden_states, key_padding_mask ) mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch} diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index a5c195883..638752e4d 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -130,17 +130,18 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
- params.is_causal = window_size_left < 0 && window_size_right == 0; - if ((window_size_left >= 0 || window_size_right >= 0) && !params.is_causal) { - params.is_local = true; - } window_size_left = std::min(int(seqlen_k), window_size_left); window_size_right = std::min(int(seqlen_k), window_size_right); - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + if (window_size_left < 0) { window_size_left = seqlen_k; } + if (window_size_right < 0) { window_size_right = seqlen_k; } params.window_size_left = window_size_left; params.window_size_right = window_size_right; + params.is_causal = window_size_left == seqlen_k && window_size_right == 0; + if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) { + params.is_local = true; + } + #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), "This flash attention build does not support local attention."); @@ -356,6 +357,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -381,8 +384,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/is_causal ? -1 : window_size_left, - /*window_size_right=*/is_causal ? 0 : window_size_right); + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right); auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -536,6 +539,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -559,8 +564,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - is_causal ? -1 : window_size_left, - is_causal ? 
0 : window_size_right, + window_size_left, + window_size_right, /*seqlenq_ngroups_swapped=*/false, /*unpadded_lse=*/true); params.total_q = total_q; @@ -621,8 +626,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const float softmax_scale, const bool is_causal, - const int window_size_left, - const int window_size_right, + int window_size_left, + int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -739,6 +744,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv_expanded = dv; } + if (is_causal) { window_size_right = 0; } + Flash_bwd_params params; set_params_dgrad(params, @@ -762,8 +769,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/is_causal ? -1 : window_size_left, - /*window_size_right=*/is_causal ? 0 : window_size_right, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, deterministic); params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); @@ -814,8 +821,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x const int max_seqlen_k, // max sequence length to choose the kernel const float softmax_scale, const bool is_causal, - const int window_size_left, - const int window_size_right, + int window_size_left, + int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -932,6 +939,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x dout_padded = dout; } + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -978,8 +987,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/is_causal ? -1 : window_size_left, - /*window_size_right=*/is_causal ? 
0 : window_size_right, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, deterministic); params.total_q = total_q; params.total_k = total_k; diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 2a610c998..fffa69c94 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -212,7 +212,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None + return dq, dk, dv, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -289,7 +289,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None def flash_attn_func( diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h index 61eee8bb2..63cfd7892 100644 --- a/hopper/flash_bwd_kernel.h +++ b/hopper/flash_bwd_kernel.h @@ -31,6 +31,7 @@ class FlashAttnBwd { // Type Aliases static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); static constexpr bool Varlen = CollectiveMainloop_::Varlen; @@ -155,6 +156,7 @@ class FlashAttnBwd { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -218,12 +220,12 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; } } - if constexpr (Is_causal) { + if constexpr (Is_causal || Is_local) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { @@ -247,13 +249,13 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } - } - if constexpr (Is_causal) { - int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); - if (m_block_min >= m_block_max) { continue; } + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } + // if constexpr (Is_causal || Is_local) { + // int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); + // int const m_block_max = 
collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); + // if (m_block_min >= m_block_max) { continue; } + // } collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } @@ -277,11 +279,14 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } - if constexpr (Is_causal) { + if constexpr (Is_causal || Is_local) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); + auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb); + auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb); + auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM); if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); continue; @@ -300,7 +305,6 @@ class FlashAttnBwd { } collective_epilogue.store_tail(); } - } }; diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index a06de561c..2fe165584 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -88,7 +88,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { params.b, params.dq_semaphore, params.cu_seqlens_q, params.cu_seqlens_k, - params.seqused_q, params.seqused_k + params.seqused_q, params.seqused_k, + params.window_size_left, params.window_size_right }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.dk_ptr), diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 60336a99e..2ed052162 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -29,9 +29,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveMainloop = flash::CollectiveMainloopFwd; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; using Scheduler = std::conditional_t< - Seqlen_traits::kUseVarSeqLen, + Seqlen_traits::kUseVarSeqLen || Is_local, flash::SingleTileScheduler, - std::conditional_t >>; @@ -137,7 +137,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { run_flash_fwd< - Flash_fwd_kernel_traits, + Flash_fwd_kernel_traits, Is_causal, Is_local, Seqlen_traits >(params, stream); }); diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 24884d406..334cdcacd 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -435,7 +435,7 @@ struct CollectiveMainloopBwd { int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_causal || Is_local) { + if constexpr (Is_local) { return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); } else { return m_block_max; @@ -613,6 
+613,15 @@ struct CollectiveMainloopBwd { } cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to } + if constexpr (Deterministic) { + constexpr int kBlockM = get<0>(TileShape_MNK{}); + int const seqlen_q = get_seqlen_q(params, bidb); + int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM); + #pragma unroll 2 + for (; m_block < m_block_global_max; ++m_block) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } } CUTLASS_DEVICE void @@ -745,7 +754,7 @@ struct CollectiveMainloopBwd { // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. - if constexpr (Is_causal || Is_local) { + if constexpr (Is_causal) { static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) { @@ -761,17 +770,12 @@ struct CollectiveMainloopBwd { warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); - int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right; - int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; + int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { if (int(get<0>(taccScS(i))) >= - std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) { + std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; - } else if constexpr (Is_local) { - if (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left)) { - tSrS(i) = -INFINITY; - } } } // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) @@ -836,9 +840,8 @@ struct CollectiveMainloopBwd { int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if ((int(get<0>(taccScS(i))) >= - std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN) - ) || (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left))) { + if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) || + (int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) { tSrS(i) = -INFINITY; } } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3ea4ee14e..111421580 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -769,11 +769,12 @@ struct CollectiveMainloopFwd { } softmax.template online_softmax(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); Tensor scores_scale = make_fragment_like(softmax.row_max); clear(scores_scale); - constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; + constexpr int n_masking_steps = (!Is_causal) ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) { @@ -791,12 +792,8 @@ struct CollectiveMainloopFwd { Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) { + if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1)) { tSrS(i) = -INFINITY; - } else if constexpr (Is_local) { - if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) { - tSrS(i) = -INFINITY; - } } } cute::copy(softmax.template max(tSrS), scores_scale); @@ -827,16 +824,17 @@ struct CollectiveMainloopFwd { #pragma unroll for (int i = 0; i < size(tSrS); ++i) { if ( - int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block) - || int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block) + int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1) || + int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block - 1) ) { tSrS(i) = -INFINITY; } } } // auto scores_scale = softmax.template max(tSrS); - cute::copy(softmax.template max(tSrS), scores_scale); - softmax.template online_softmax(tSrS); + cute::copy(softmax.template max(tSrS), scores_scale); + softmax.template online_softmax(tSrS); + warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V ++smem_pipe_read_k; @@ -849,11 +847,10 @@ struct CollectiveMainloopFwd { softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - cute::copy(softmax.template finalize(tSrS), scores_scale); + cute::copy(softmax.template finalize(tSrS), scores_scale); warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang ++smem_pipe_read_v; - softmax.rescale_o(tOrO, scores_scale); return; } diff --git a/hopper/setup.py b/hopper/setup.py index 12e86667c..2a23d1657 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -144,7 +144,7 @@ def append_nvcc_threads(nvcc_extra_args): "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers "-lineinfo", "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging - "-DNDEBUG", # Important, otherwise performance is severely impacted + "-DNDEBUG", # Important, otherwise performance is severely impacted ] include_dirs = [ # Path(this_dir) / "fmha-pipeline", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 6a098f7c5..5c15d02d0 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -1,5 +1,8 @@ import math +import sys +sys.path.remove("/home/yingz/llm_inference") + import pytest import torch import torch.nn.functional as F @@ -24,11 +27,14 @@ def print_diffs(out, out_ref): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) 
@pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -37,7 +43,7 @@ def print_diffs(out, out_ref): # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize("d", [64, 96, 128]) -# @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [256]) @pytest.mark.parametrize("d", [64, 128, 256]) @pytest.mark.parametrize("descale", [1.0]) # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0]) @@ -65,13 +71,13 @@ def print_diffs(out, out_ref): ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype, descale + seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale ): device = "cuda" if(dtype == torch.float8_e4m3fn): dtype_init = torch.float16 else: - dtype_init = dtype + dtype_init = dtype print(dtype) # set seed torch.random.manual_seed(0) @@ -80,9 +86,11 @@ def test_flash_attn_output( batch_size = 4 nheads = 6 nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - # nheads_kv = 2 - # batch_size = 9 - # nheads = 6 + # nheads_kv = 1 + # batch_size = 1 + # nheads = 1 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + print(f"window_size: {window_size}", flush=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True) @@ -96,7 +104,7 @@ def test_flash_attn_output( descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda') descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda') if(dtype != torch.float8_e4m3fn): - out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic) + out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic) else: out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward( q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v @@ -113,7 +121,7 @@ def test_flash_attn_output( q = q * descale_q k = k * descale_k v = v * descale_v - + out_ref, attn_ref = attention_ref( q, k, @@ -121,6 +129,7 @@ def test_flash_attn_output( None, None, causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_ref( q, @@ -129,6 +138,7 @@ def test_flash_attn_output( None, None, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -144,9 +154,9 @@ def test_flash_attn_output( print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - + # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() if d <= 128 and dtype != torch.float8_e4m3fn: @@ -181,7 +191,7 @@ def test_flash_attn_output( # breakpoint() if(dtype != torch.float8_e4m3fn): assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5 - else: + else: # just test correctness of fp8 kernel w/o further quantization techniques assert (out - out_ref).abs().max().item() <= 40 
* (out_pt - out_ref).abs().max().item() @@ -196,14 +206,16 @@ def test_flash_attn_output( @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [128]) +# @pytest.mark.parametrize('d', [256]) # @pytest.mark.parametrize("d", [64, 128, 256]) @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128]) @@ -233,7 +245,7 @@ def test_flash_attn_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, deterministic, add_unused_qkv, mha_type, dtype ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -243,12 +255,15 @@ def test_flash_attn_varlen_output( device = "cuda" # set seed torch.random.manual_seed(0) - # batch_size = 1 - # nheads = 1 batch_size = 9 - nheads = 6 - nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - + nheads = 4 + nheads_kv = 4 + # batch_size = 9 + # nheads = 6 + # nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn( batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True @@ -308,10 +323,12 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): deterministic=deterministic, seqused_q=seqused_q, seqused_k=seqused_k, + window_size=window_size, ) out = output_pad_fn(out_unpad) - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - out.masked_fill_(q_zero_masking, 0.0) + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + out.masked_fill_(q_zero_masking, 0.0) dropout_mask = None out_ref, attn_ref = attention_ref( @@ -321,6 +338,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + window_size=window_size, ) out_pt, attn_pt = attention_ref( q, @@ -329,6 +347,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + window_size=window_size, upcast=False, reorder_ops=True, ) @@ -347,9 +366,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) - k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") - dk.masked_fill_(k_zero_masking, 0.0) - dv.masked_fill_(k_zero_masking, 0.0) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + 
dv.masked_fill_(k_zero_masking, 0.0) ( dq_ref, dk_ref, diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index ac93ca94b..2fbb417e4 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -270,4 +270,4 @@ class DynamicPersistentTileScheduler { }; -} // flash \ No newline at end of file +} // flash diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 72d55134e..d5bb6ba85 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -89,7 +89,7 @@ def generate_qkv( assert v.shape == (batch_size, seqlen_k, nheads_k, d) if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) @@ -104,8 +104,8 @@ def generate_qkv( ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") diff --git a/tests/test_rotary.py b/tests/test_rotary.py index 6f2a5fae7..f6b3e5aed 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -215,7 +215,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of x_pt = x.detach().clone().requires_grad_() lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device) padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths - x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask) + x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask) x_unpad_clone = x_unpad.clone() x_unpad = x_unpad.requires_grad_() cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
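
Note on the unpad_input change at the top of this patch: flash_attn.bert_padding.unpad_input now always returns a 5-tuple whose last element is used_seqlens_in_batch, which is why every caller touched here unpacks five values and discards the trailing one with "_" where it is not needed. Below is a minimal usage sketch; the shapes and lengths are chosen only for illustration.

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, nheads, d = 2, 8, 3, 16
hidden = torch.randn(batch, seqlen, nheads, d)
# Boolean key-padding mask with valid lengths 5 and 7.
mask = torch.arange(seqlen)[None, :] < torch.tensor([5, 7])[:, None]

# After this patch unpad_input always returns five values; the fifth
# (used_seqlens_in_batch) was previously returned only when an unused_mask
# was passed, hence the trailing "_" in the updated callers.
hidden_unpad, indices, cu_seqlens, max_seqlen, used_seqlens = unpad_input(hidden, mask)
print(hidden_unpad.shape)  # torch.Size([12, 3, 16]) -- 5 + 7 unpadded rows
print(cu_seqlens)          # tensor([ 0,  5, 12], dtype=torch.int32)
print(used_seqlens)        # tensor([5, 7], dtype=torch.int32)

# Round-trip back to the padded (batch, seqlen, ...) layout.
hidden_padded = pad_input(hidden_unpad, indices, batch, seqlen)
assert torch.equal(hidden_padded[mask], hidden[mask])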
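
Note on the window_size handling in hopper/flash_api.cpp: the entry points now set window_size_right = 0 whenever is_causal is true and pass the window through unchanged, while set_params_fprop clamps negative sizes to seqlen_k before re-deriving is_causal and is_local, so causal attention becomes the special case (seqlen_k, 0) of local attention. The following is a Python paraphrase of that normalization, with a hypothetical helper name, just to make the branch conditions explicit; it is a sketch, not the C++ implementation.

def normalize_window(seqlen_k: int, window_size_left: int, window_size_right: int,
                     is_causal: bool):
    """Mirror of the window normalization in this patch's set_params_fprop."""
    # Caller-side step (mha_fwd / mha_varlen_fwd / mha_bwd / mha_varlen_bwd):
    if is_causal:
        window_size_right = 0
    # set_params_fprop: clamp, then treat negative sizes as "unlimited".
    window_size_left = min(seqlen_k, window_size_left)
    window_size_right = min(seqlen_k, window_size_right)
    if window_size_left < 0:
        window_size_left = seqlen_k
    if window_size_right < 0:
        window_size_right = seqlen_k
    is_causal = window_size_left == seqlen_k and window_size_right == 0
    is_local = (window_size_left < seqlen_k or window_size_right < seqlen_k) and not is_causal
    return window_size_left, window_size_right, is_causal, is_local

# Examples with seqlen_k = 1024:
# (-1, -1), is_causal=False -> (1024, 1024), causal=False, local=False (full attention)
# (-1, -1), is_causal=True  -> (1024, 0),    causal=True,  local=False
# (128, 0), is_causal=False -> (128, 0),     causal=False, local=True  (sliding window)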
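
Note on the masking arithmetic: the new Is_local paths in the Hopper mainloops and the window_size argument now passed to attention_ref in hopper/test_flash_attn.py both follow the bottom-right-aligned convention, i.e. offsets are taken relative to seqlen_k - seqlen_q (the causal_row_offset and local_row_offset_left/right terms above). Here is a small reference sketch of that mask, assuming that convention; the helper is illustrative only and not part of the library.

import torch

def local_mask(seqlen_q: int, seqlen_k: int, window_left: int, window_right: int) -> torch.Tensor:
    """Boolean (seqlen_q, seqlen_k) mask, True where attention is allowed.

    Bottom-right aligned: the seqlen_k - seqlen_q shift matches the
    causal_row_offset / local_row_offset_{left,right} terms in the bwd mainloop.
    window_left = seqlen_k, window_right = 0 gives bottom-right causal masking.
    """
    row = torch.arange(seqlen_q)[:, None]   # query index i
    col = torch.arange(seqlen_k)[None, :]   # key index j
    shift = seqlen_k - seqlen_q
    return (col >= row + shift - window_left) & (col <= row + shift + window_right)

# Causal on a square problem: lower-triangular.
assert torch.equal(local_mask(4, 4, 4, 0), torch.ones(4, 4, dtype=torch.bool).tril())
# Window of 2 keys before the diagonal plus the diagonal, none to the right.
print(local_mask(4, 4, 2, 0).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [0, 1, 1, 1]], dtype=torch.int32)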
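
Note on the if constexpr (Deterministic) tail loop added to mainloop_bwd_sm90_tma_gmma_ws.hpp: with a finite left window, get_m_block_max truncates the m_block range of early n_blocks, but the deterministic dQ path orders writers per m_block with a semaphore whose expected count is the writer's n_block index, so each n_block must still arrive on the m_blocks above its m_block_max (up to ceil_div(seqlen_q, kBlockM)) or a later n_block would wait forever. The sketch below is a toy, sequential model of that invariant, not kernel code; for simplicity it arrives on every skipped m_block and reports failure where the real kernel would spin on the semaphore.

def deterministic_bwd_schedule_ok(num_m_blocks, num_n_blocks, m_range, arrive_for_skipped):
    """Return True if no n_block worker would wait forever on the dQ semaphore.

    semaphore[m] counts arrivals for m_block m; worker n may accumulate dQ for
    m_block m only once semaphore[m] == n (all earlier workers have arrived).
    """
    semaphore = [0] * num_m_blocks
    for n_block in range(num_n_blocks):          # workers in issue order
        m_min, m_max = m_range(n_block)
        for m_block in range(num_m_blocks):
            visited = m_min <= m_block < m_max
            if visited or arrive_for_skipped:
                if semaphore[m_block] != n_block:
                    return False                 # this worker would spin forever
                semaphore[m_block] += 1          # models Barrier::arrive_inc
    return True

# Local-attention-style ranges: a small left window gives early n_blocks a small
# m_block_max, so they skip m_blocks that later n_blocks still need.
m_range = lambda n: (max(0, n - 1), min(4, n + 2))
print(deterministic_bwd_schedule_ok(4, 4, m_range, arrive_for_skipped=True))   # True
print(deterministic_bwd_schedule_ok(4, 4, m_range, arrive_for_skipped=False))  # False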