diff --git a/.eggs/README.txt b/.eggs/README.txt new file mode 100644 index 000000000..5d0166882 --- /dev/null +++ b/.eggs/README.txt @@ -0,0 +1,6 @@ +This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins. + +This directory caches those eggs to prevent repeated downloads. + +However, it is safe to delete this directory. + diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 1d447d3f6..71ab43d8f 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -95,19 +95,23 @@ def backward(ctx, grad_output, grad_residual): index_first_axis_residual = IndexFirstAxisResidual.apply -def unpad_input(hidden_states, attention_mask): +def unpad_input(hidden_states, attention_mask, unused_mask=None): """ Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. max_seqlen_in_batch: int + seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None. """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the @@ -120,6 +124,7 @@ def unpad_input(hidden_states, attention_mask): indices, cu_seqlens, max_seqlen_in_batch, + used_seqlens_in_batch, ) diff --git a/flash_attn/flash_blocksparse_attention.py b/flash_attn/flash_blocksparse_attention.py index 03798d16f..4c9302910 100644 --- a/flash_attn/flash_blocksparse_attention.py +++ b/flash_attn/flash_blocksparse_attention.py @@ -99,7 +99,7 @@ def forward( key_padding_mask_bool = key_padding_mask.bool_matrix nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) + x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool) x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_blocksparse_attn_func( x_unpad, diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index 33d693520..6a78b1ea9 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -172,7 +172,7 @@ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): hidden_states = hidden_states[subset_mask] else: batch, seqlen = hidden_states.shape[:2] - hidden_states, indices, cu_seqlens, 
max_seqlen_in_batch = unpad_input( + hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input( hidden_states, key_padding_mask ) mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch} diff --git a/hopper/epilogue_bwd_sm90_tma.hpp b/hopper/epilogue_bwd_sm90_tma.hpp index b6741120e..036ed09c2 100644 --- a/hopper/epilogue_bwd_sm90_tma.hpp +++ b/hopper/epilogue_bwd_sm90_tma.hpp @@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd { Element* ptr_dV; StridedKV const stride_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Device side kernel params @@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd { StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; static Params @@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd { select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, - tma_store_dK, tma_store_dV, args.cu_seqlens}; + tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); bool const is_varlen = params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : ( + params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] + ); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) @@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd { auto [n_block, bidh, bidb] = block_coord; bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) diff --git a/hopper/flash.h b/hopper/flash.h index bc58f1abc..24ca27f69 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; - // If provided, the actual length of each k sequence. + // If provided, the actual length of each q / o sequence. + int * __restrict__ seqused_q; + // If provided, the actual length of each k / v sequence. int * __restrict__ seqused_k; int *__restrict__ blockmask; @@ -116,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params { bool is_bf16; bool is_e4m3; bool is_causal; + bool is_local; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. 
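A minimal usage sketch of the extended unpad_input above, assuming the patched flash_attn.bert_padding from this diff; the tensors are illustrative only. unused_mask marks slots that are packed into the output but not attended, and the extra seqused return value is what callers can later hand to the kernels as seqused_q / seqused_k:

import torch
from flash_attn.bert_padding import unpad_input

batch, seqlen, dim = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, dim)
# 1 = token participates in attention
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 0]], dtype=torch.int32)
# 1 = slot is allocated (e.g. reserved cache space) but not attended this step
unused_mask = torch.tensor([[0, 0, 1, 0],
                            [0, 0, 0, 1]], dtype=torch.int32)

hidden_unpad, indices, cu_seqlens, max_seqlen, seqused = unpad_input(
    hidden_states, attention_mask, unused_mask
)
# cu_seqlens counts allocated (attended + unused) tokens per sequence: [0, 3, 7]
# seqused counts attended tokens per sequence: [2, 3]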
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 27fcc5858..638752e4d 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, void *seqused_k, void *p_d, void *softmax_lse_d, @@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); params.seqused_k = static_cast(seqused_k); TORCH_CHECK( @@ -128,13 +130,18 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + window_size_left = std::min(int(seqlen_k), window_size_left); + window_size_right = std::min(int(seqlen_k), window_size_right); + if (window_size_left < 0) { window_size_left = seqlen_k; } + if (window_size_right < 0) { window_size_right = seqlen_k; } params.window_size_left = window_size_left; params.window_size_right = window_size_right; + params.is_causal = window_size_left == seqlen_k && window_size_right == 0; + if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) { + params.is_local = true; + } + #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), "This flash attention build does not support local attention."); @@ -171,6 +178,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, at::Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, @@ -187,7 +196,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, - nullptr, + seqused_q, + seqused_k, nullptr, softmax_lse_d, p_dropout, @@ -268,7 +278,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size c10::optional &descale_q_, // 1 c10::optional &descale_k_, // 1 c10::optional &descale_v_, // 1 - bool is_causal) { + bool is_causal, + int window_size_left, + int window_size_right) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -345,6 +357,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -364,13 +378,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size q_padded, k_padded, v_padded, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, /*seqused_k=*/nullptr, nullptr, softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 
0 : -1); + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right); auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -426,11 +441,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. int max_seqlen_q, const int max_seqlen_k, const float softmax_scale, - bool is_causal) { + bool is_causal, + int window_size_left, + int window_size_right) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -461,10 +479,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int head_size_og = sizes[2]; const int num_heads_k = k.size(1); - int window_size_left = -1; - int window_size_right = -1; - if (is_causal) { window_size_right = 0; } - void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); const int total_q = q.sizes()[0]; @@ -473,15 +487,20 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (window_size_left >= max_seqlen_k) { window_size_left = -1; } - if (window_size_right >= max_seqlen_k) { window_size_right = -1; } - CHECK_SHAPE(q, total_q, num_heads, head_size_og); const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); if (seqused_k.has_value()){ auto seqused_k_ = seqused_k.value(); @@ -520,6 +539,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -537,6 +558,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s q_padded, k_padded, v_padded, out, cu_seqlens_q_d, cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, seqused_k.has_value() ? 
seqused_k.value().data_ptr() : nullptr, /*p_d=*/nullptr, softmax_lse.data_ptr(), @@ -604,6 +626,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const float softmax_scale, const bool is_causal, + int window_size_left, + int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -720,6 +744,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv_expanded = dv; } + if (is_causal) { window_size_right = 0; } + Flash_bwd_params params; set_params_dgrad(params, @@ -730,8 +756,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si head_size, head_size_rounded, q, k, v, out, dout_padded, dq, dk_expanded, dv_expanded, - nullptr, - nullptr, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, @@ -741,8 +769,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 0 : -1, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, deterministic); params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); @@ -787,10 +815,14 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
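The window-size handling that set_params_fprop gains earlier in this diff reduces to a small amount of arithmetic; the Python below is an illustrative mirror of that C++ (not part of the patch), with the call sites setting window_size_right = 0 beforehand when is_causal is requested:

# A negative window size means "unbounded on that side".
def classify_window(window_size_left, window_size_right, seqlen_k):
    window_size_left = min(seqlen_k, window_size_left)
    window_size_right = min(seqlen_k, window_size_right)
    if window_size_left < 0:
        window_size_left = seqlen_k
    if window_size_right < 0:
        window_size_right = seqlen_k
    is_causal = window_size_left == seqlen_k and window_size_right == 0
    is_local = (window_size_left < seqlen_k or window_size_right < seqlen_k) and not is_causal
    return window_size_left, window_size_right, is_causal, is_local

# With seqlen_k = 1024:
#   (-1, -1) -> (1024, 1024), is_causal=False, is_local=False   # full attention
#   (-1,  0) -> (1024,    0), is_causal=True,  is_local=False   # causal
#   (128, 0) -> ( 128,    0), is_causal=False, is_local=True    # causal sliding window
#   (64, 64) -> (  64,   64), is_causal=False, is_local=True    # symmetric local window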
const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float softmax_scale, const bool is_causal, + int window_size_left, + int window_size_right, const bool deterministic) { #ifdef FLASHATTENTION_DISABLE_BACKWARD @@ -854,7 +886,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x CHECK_SHAPE(out, total_q, num_heads, head_size); CHECK_SHAPE(dout, total_q, num_heads, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } at::Tensor dq, dk, dv; if (dq_.has_value()) { @@ -892,6 +939,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x dout_padded = dout; } + if (is_causal) { window_size_right = 0; } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)q.get_device()}; @@ -927,6 +976,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x dout_padded, dq, dk_expanded, dv_expanded, cu_seqlens_q.data_ptr(), cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, @@ -936,8 +987,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 
0 : -1, + /*window_size_left=*/window_size_left, + /*window_size_right=*/window_size_right, deterministic); params.total_q = total_q; params.total_k = total_k; diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 29f66c99f..fffa69c94 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None): +def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, @@ -26,6 +26,8 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descal descale_k, descale_v, causal, + window_size[0], + window_size[1], ) return out, q, k, v, out_padded, softmax_lse, S_dmask @@ -42,6 +44,7 @@ def _flash_attn_backward( dv, softmax_scale, causal, + window_size, deterministic=False ): # dq, dk, dv are allocated by us so they should already be contiguous @@ -58,6 +61,8 @@ def _flash_attn_backward( dv, softmax_scale, causal, + window_size[0], + window_size[1], deterministic, ) return dq, dk, dv, softmax_d @@ -72,6 +77,9 @@ def _flash_attn_varlen_forward( max_seqlen_k, softmax_scale, causal, + window_size=(-1, -1), + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -82,11 +90,14 @@ def _flash_attn_varlen_forward( None, cu_seqlens_q, cu_seqlens_k, - None, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal, + window_size[0], + window_size[1], ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -109,7 +120,10 @@ def _flash_attn_varlen_backward( max_seqlen_k, softmax_scale, causal, + window_size, deterministic=False, + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous @@ -132,10 +146,14 @@ def _flash_attn_varlen_backward( dv, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal, + window_size[0], + window_size[1], deterministic, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): @@ -152,6 +170,7 @@ def forward( v, softmax_scale, causal, + window_size, deterministic=False, ): if softmax_scale is None: @@ -161,11 +180,13 @@ def forward( k, v, softmax_scale, - causal + causal, + window_size ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.deterministic = deterministic return out, softmax_lse @@ -185,12 +206,13 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, + ctx.window_size, ctx.deterministic, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None + return dq, dk, dv, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -206,7 +228,10 @@ def forward( max_seqlen_k, softmax_scale, causal, + window_size, deterministic=False, + seqused_q=None, + seqused_k=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] 
** (-0.5) @@ -220,20 +245,25 @@ def forward( max_seqlen_k, softmax_scale, causal=causal, + window_size=window_size, + seqused_q=seqused_q, + seqused_k=seqused_k, ) ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k ) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.window_size = window_size ctx.deterministic = deterministic return out, softmax_lse @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_varlen_backward( dout, @@ -251,12 +281,15 @@ def backward(ctx, dout, *args): ctx.max_seqlen_k, ctx.softmax_scale, ctx.causal, + ctx.window_size, ctx.deterministic, + seqused_q, + seqused_k, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None def flash_attn_func( @@ -265,6 +298,7 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + window_size=(-1, -1), deterministic=False ): """dropout_p should be set to 0.0 during evaluation @@ -321,6 +355,7 @@ def flash_attn_func( v, softmax_scale, causal, + window_size, deterministic, ) @@ -335,7 +370,10 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale=None, causal=False, + window_size=(-1, -1), deterministic=False, + seqused_q=None, + seqused_k=None, ): """ Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -366,6 +404,11 @@ def flash_attn_varlen_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + query and output tokens in each sequence. + seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + key and value tokens in each sequence. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The @@ -382,5 +425,8 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + window_size, deterministic, + seqused_q, + seqused_k, ) diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h index f4ba0ff47..63cfd7892 100644 --- a/hopper/flash_bwd_kernel.h +++ b/hopper/flash_bwd_kernel.h @@ -31,6 +31,7 @@ class FlashAttnBwd { // Type Aliases static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; + static constexpr bool Is_local = CollectiveMainloop_::Is_local; static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); static constexpr bool Varlen = CollectiveMainloop_::Varlen; @@ -155,6 +156,7 @@ class FlashAttnBwd { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -218,14 +220,14 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; } } - if constexpr (Is_causal) { + if constexpr (Is_causal || Is_local) { int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); if (m_block_min >= m_block_max) { scheduler.prefetch_next_work(params.scheduler, work_tile_info); continue; @@ -247,13 +249,13 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } - } - if constexpr (Is_causal) { - int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); - if (m_block_min >= m_block_max) { continue; } + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } + // if constexpr (Is_causal || Is_local) { + // int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); + // int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); + // if (m_block_min >= m_block_max) { continue; } + // } collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } @@ -277,11 +279,14 @@ class FlashAttnBwd { auto block_coord = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb] = block_coord; if constexpr (Varlen) { - if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } + if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; } } - if constexpr (Is_causal) { + if constexpr (Is_causal || Is_local) { int const m_block_min = 
collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb); - int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM); + int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb); + auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb); + auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb); + auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM); if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); continue; @@ -300,7 +305,6 @@ class FlashAttnBwd { } collective_epilogue.store_tail(); } - } }; diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index f8ef642ef..2fe165584 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -20,7 +20,7 @@ using namespace cute; -template void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using TileShape_MK = cute::Shape, Int>; @@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQ params.b, params.dq_semaphore, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); @@ -56,7 +57,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape<_1, Int<1>, _1>; static constexpr int Stages = 2; using CollectiveMainloop = flash::CollectiveMainloopBwd; using CollectiveEpilogue = flash::CollectiveEpilogueBwd; using Scheduler = flash::SingleTileSchedulerBwd; @@ -87,6 +88,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { params.b, params.dq_semaphore, params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k, + params.window_size_left, params.window_size_right }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.dk_ptr), @@ -146,7 +149,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? 
params.b : 1}, // shape_dQ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ params.scale_softmax, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); @@ -167,9 +171,11 @@ template void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); @@ -179,9 +185,11 @@ template void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); @@ -191,9 +199,11 @@ template void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { - BOOL_SWITCH(params.deterministic, Deterministic, [&] { - run_flash_bwd(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] { + BOOL_SWITCH(params.deterministic, Deterministic, [&] { + run_flash_bwd(params, stream); + }); }); }); }); diff --git a/hopper/flash_bwd_postprocess_kernel.h b/hopper/flash_bwd_postprocess_kernel.h index 3c54647d0..f31912add 100644 --- a/hopper/flash_bwd_postprocess_kernel.h +++ b/hopper/flash_bwd_postprocess_kernel.h @@ -102,6 +102,7 @@ class FlashAttnBwdPostprocessConvertdQ { StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -113,6 +114,7 @@ class FlashAttnBwdPostprocessConvertdQ { StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -133,7 +135,8 @@ class FlashAttnBwdPostprocessConvertdQ { args.shape_dQ, args.stride_dQ, args.softmax_scale, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -156,7 +159,7 @@ class FlashAttnBwdPostprocessConvertdQ { int const bidb = blockIdx.z; bool const is_varlen = params.cu_seqlens != nullptr; - int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? 
params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]); if (is_varlen && m_block * kBlockM >= seqlen) { return; } int lane_predicate = cute::elect_one_sync(); diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h index 5dc5e063c..86322ea79 100644 --- a/hopper/flash_bwd_preprocess_kernel.h +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -85,6 +85,7 @@ class FlashAttnBwdPreprocess { int num_batch; // We need this to know the size of dq_semaphore in case of varlen int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -107,6 +108,7 @@ class FlashAttnBwdPreprocess { int num_batch; int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -131,7 +133,8 @@ class FlashAttnBwdPreprocess { args.stride_dQaccum, args.num_batch, args.dq_semaphore, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -148,7 +151,7 @@ class FlashAttnBwdPreprocess { bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : params.cu_seqlens[bidb + 1] - offset_o; + int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o); if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index 5d79ddeb9..9517c5e0c 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -24,9 +24,9 @@ namespace flash { using namespace cute; -template +template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k @@ -47,7 +47,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, // static constexpr int kBlockN = Ktraits::kBlockN; // constexpr int kHeadDim = Ktraits::kHeadDim; - using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveMainloop = CollectiveMainloopFwd; using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -121,9 +121,11 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { continue; } - int n_block_max = collective_mainloop.get_n_block_max( + const int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = collective_mainloop.get_n_block_min( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { + if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); 
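For intuition about the early-exit check above (n_block_max <= n_block_min), the range of k/v blocks a query block visits is computed by get_n_block_min / get_n_block_max, which this diff adds to mainloop_fwd_sm90_tma_gmma_ws.hpp further down. The Python sketch below reproduces that arithmetic for illustration only:

def cdiv(a, b):
    return (a + b - 1) // b

def n_block_range(m_block, seqlen_q, seqlen_k, kBlockM, kBlockN,
                  window_size_left, window_size_right, is_causal, is_local):
    # Last k/v block (exclusive) any query row in this m_block may attend to.
    n_block_max = cdiv(seqlen_k, kBlockN)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            cdiv((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right, kBlockN),
        )
    # First k/v block still inside the left window (always 0 unless local).
    n_block_min = 0
    if is_local:
        n_block_min = max(
            0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) // kBlockN
        )
    return n_block_min, n_block_max   # the kernel skips the tile when max <= min

# e.g. seqlen_q = seqlen_k = 1024, kBlockM = kBlockN = 128, window_size = (128, 0), local:
#   m_block = 0 -> (0, 1)    m_block = 7 -> (6, 8)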
scheduler.broadcast_next_work(work_tile_info); continue; @@ -167,15 +169,17 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { continue; } - int n_block_max = collective_mainloop.get_n_block_max( + const int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = collective_mainloop.get_n_block_min( mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); - if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + if ((Is_causal || Is_local || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= n_block_min) { // We exit early and write 0 to gO and -inf to gLSE. collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); continue; } collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, - tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, + tOrO, softmax, n_block_max, n_block_min, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, seqlen_traits_q, seqlen_traits_k); // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, @@ -190,7 +194,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k @@ -215,7 +219,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128; static constexpr bool Use_max_offset = true; - using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveMainloop = CollectiveMainloopFwd; using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 9b5979763..2ed052162 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -18,7 +18,7 @@ #include "utils.h" -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using Element = typename Kernel_traits::Element; using OutputType = typename Kernel_traits::OutputType; @@ -26,10 +26,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = typename Kernel_traits::ClusterShape_MNK; // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); - using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveMainloop = flash::CollectiveMainloopFwd; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; using Scheduler = std::conditional_t< - Seqlen_traits::kUseVarSeqLen, + Seqlen_traits::kUseVarSeqLen || Is_local, flash::SingleTileScheduler, 
std::conditional_t>; // using Scheduler = flash::SingleTileScheduler; Seqlen_traits seqlen_traits_q( - params.total_q, params.seqlen_q, params.cu_seqlens_q); + params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q); Seqlen_traits seqlen_traits_k( params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); typename CollectiveMainloop::Params mainloop_params = @@ -60,7 +60,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.scale_softmax_log2, params.descale_q_ptr, params.descale_k_ptr, - params.descale_v_ptr + params.descale_v_ptr, + params.window_size_left, + params.window_size_right }); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ @@ -85,7 +87,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if constexpr(cutlass::sizeof_bits_v == 8) kernel = (void *)flash::compute_attn_ws_fp8; else - kernel = (void *)flash::compute_attn_ws; + kernel = (void *)flash::compute_attn_ws; int smem_size = sizeof(typename Kernel_traits::SharedStorage); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); @@ -115,11 +117,13 @@ template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); } @@ -128,13 +132,15 @@ template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); }); @@ -144,13 +150,15 @@ template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { - run_flash_fwd< - Flash_fwd_kernel_traits, - Is_causal, Seqlen_traits - >(params, stream); + BOOL_SWITCH(params.is_local, Is_local, [&] { + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 
== 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Is_local, Seqlen_traits + >(params, stream); + }); }); }); }); @@ -166,11 +174,11 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { @@ -195,11 +203,11 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { @@ -224,11 +232,11 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) { using Seqlen_traits = flash::FixedSeqLenTraits; if(params.is_causal) { run_flash_fwd, /*Is_causal=*/true, Seqlen_traits>(params, stream); + false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream); } else { BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] { run_flash_fwd, /*Is_causal=*/false, Seqlen_traits>(params, stream); + false, UseCluster ? 
2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream); }); } // BOOL_SWITCH(params.is_causal, Is_causal, [&] { diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 7483f4efd..334cdcacd 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -24,7 +24,7 @@ namespace flash { using namespace cute; template struct CollectiveMainloopBwd { @@ -36,6 +36,7 @@ struct CollectiveMainloopBwd { using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_causal = Is_causal_; + static constexpr bool Is_local = Is_local_; static constexpr bool Varlen = Varlen_; static constexpr bool SdP_swapAB = true; static constexpr bool dKV_swapAB = dKV_swapAB_; @@ -279,6 +280,10 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_k = nullptr; + int const* seqused_v = nullptr; + int window_size_left; + int window_size_right; }; // Device side kernel params @@ -303,6 +308,10 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_q = nullptr; + int const* seqused_k = nullptr; + int window_size_left; + int window_size_right; }; static Params @@ -362,7 +371,8 @@ struct CollectiveMainloopBwd { tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, float(args.softmax_scale * M_LOG2E), - args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k}; + args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, + args.seqused_k, args.seqused_v, args.window_size_left, args.window_size_right}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -384,7 +394,10 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_q == nullptr ? get<0>(params.shape_Q) - : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]; + : (params.seqused_q + ? params.seqused_q[bidb] + : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb] + ); } } @@ -395,21 +408,40 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_k == nullptr ? get<0>(params.shape_K) - : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]; + : (params.seqused_k + ? 
params.seqused_k[bidb] + : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] + ); } } CUTLASS_DEVICE int get_m_block_min(Params const& params, int n_block, int bidb) { - if constexpr (Is_causal) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + if constexpr (Is_causal || Is_local) { int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); - return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM); + return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); } else { return 0; } } + CUTLASS_DEVICE + int get_m_block_max(Params const& params, int n_block, int bidb) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = get_seqlen_q(params, bidb); + int const seqlen_k = get_seqlen_k(params, bidb); + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_local) { + return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); + } else { + return m_block_max; + } + } + template CUTLASS_DEVICE void load(Params const& params, @@ -480,7 +512,7 @@ struct CollectiveMainloopBwd { } } - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; @@ -557,7 +589,7 @@ struct CollectiveMainloopBwd { Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; int const num_batch = params.num_batch; @@ -581,6 +613,15 @@ struct CollectiveMainloopBwd { } cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to } + if constexpr (Deterministic) { + constexpr int kBlockM = get<0>(TileShape_MNK{}); + int const seqlen_q = get_seqlen_q(params, bidb); + int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM); + #pragma unroll 2 + for (; m_block < m_block_global_max; ++m_block) { + Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); + } + } } CUTLASS_DEVICE void @@ -667,7 +708,7 @@ struct CollectiveMainloopBwd { int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); - int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); + int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; @@ -732,8 +773,8 @@ struct CollectiveMainloopBwd { int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset, - seqlen_k - n_block * kBlockN)) { + if (int(get<0>(taccScS(i))) >= + std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } @@ -789,10 +830,23 @@ struct 
CollectiveMainloopBwd { warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + if constexpr (!Is_local) { + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } + } else { + int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right; + int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) || + (int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) { + tSrS(i) = -INFINITY; + } + } } + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); } @@ -838,4 +892,3 @@ struct CollectiveMainloopBwd { }; } // namespace flash - diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 8094ad3ab..111421580 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -79,7 +79,7 @@ struct SmemTransposeFp8_64x64 { } }; -template +template struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; @@ -158,6 +158,8 @@ struct CollectiveMainloopFwd { float const* descale_q_ptr; float const* descale_k_ptr; float const* descale_v_ptr; + int window_size_left; + int window_size_right; }; // Device side kernel params @@ -173,6 +175,8 @@ struct CollectiveMainloopFwd { float const* descale_q_ptr; float const* descale_k_ptr; float const* descale_v_ptr; + int window_size_left; + int window_size_right; }; @@ -203,7 +207,8 @@ struct CollectiveMainloopFwd { cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2, - args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr}; + args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr, + args.window_size_left, args.window_size_right}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -225,13 +230,34 @@ struct CollectiveMainloopFwd { int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? 
seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN)); + if constexpr (Is_causal || Is_local) { + n_block_max = std::min( + n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + mainloop_params.window_size_right, kBlockN)); } return n_block_max; } + CUTLASS_DEVICE + int get_n_block_min( + Params const& mainloop_params, int m_block, + const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k + ) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q); + int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K); + if constexpr (!Is_local) { + return 0; + } else { + return std::max( + 0, + (m_block * kBlockM + seqlen_k - seqlen_q - mainloop_params.window_size_left) / kBlockN + ); + } + } + template CUTLASS_DEVICE void load(Params const& mainloop_params, @@ -288,7 +314,8 @@ struct CollectiveMainloopFwd { } } - int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); + const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -315,7 +342,7 @@ struct CollectiveMainloopFwd { if (lane_predicate) { // CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 2 - for (; n_block > 0; --n_block) { + for (; n_block > n_block_min; --n_block) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index())); @@ -645,6 +672,7 @@ struct CollectiveMainloopFwd { FrgTensorO& tOrO, Softmax& softmax, int n_block_count, + int n_block_min, int thread_idx, int work_idx, int m_block, @@ -706,37 +734,50 @@ struct CollectiveMainloopFwd { pipeline_k.consumer_release(smem_pipe_read_k); ++smem_pipe_read_k; - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + auto col_limit_right = [&](int row, int n_block) { + return std::min( + seqlen_k - n_block * kBlockN, + row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right + ); + }; + auto col_limit_left = [&](int row, int n_block) { + return std::max( + 0, + row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - mainloop_params.window_size_left + ); }; { Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); Tensor tScS = threadMma0.partition_C(cS); #pragma unroll for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col + if constexpr (!Is_causal && !Is_local) { // Just masking based on col if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } else { // mask based on both row and col // using std::min is faster than doing col >= limit0 or col >= limit1 // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the // right hand side can be 
@@ -288,7 +314,8 @@ struct CollectiveMainloopFwd {
             }
         }
 
-        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+        const int n_block_min = get_n_block_min(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+        const int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
         int n_block = n_block_max - 1;
 
         int lane_predicate = cute::elect_one_sync();
@@ -315,7 +342,7 @@ struct CollectiveMainloopFwd {
         if (lane_predicate) {
             // CUTLASS_PRAGMA_NO_UNROLL
             #pragma unroll 2
-            for (; n_block > 0; --n_block) {
+            for (; n_block > n_block_min; --n_block) {
                 pipeline_k.producer_acquire(smem_pipe_write_k);
                 copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                      tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
@@ -645,6 +672,7 @@ struct CollectiveMainloopFwd {
               FrgTensorO& tOrO,
               Softmax& softmax,
               int n_block_count,
+              int n_block_min,
               int thread_idx,
               int work_idx,
               int m_block,
@@ -706,37 +734,50 @@ struct CollectiveMainloopFwd {
         pipeline_k.consumer_release(smem_pipe_read_k);
         ++smem_pipe_read_k;
 
-        auto col_limit_causal = [&](int row, int n_block) {
-            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
+        auto col_limit_right = [&](int row, int n_block) {
+            return std::min(
+                seqlen_k - n_block * kBlockN,
+                row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + mainloop_params.window_size_right
+            );
+        };
+        auto col_limit_left = [&](int row, int n_block) {
+            return std::max(
+                0,
+                row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - mainloop_params.window_size_left
+            );
         };
         {
             Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
             Tensor tScS = threadMma0.partition_C(cS);
             #pragma unroll
             for (int i = 0; i < size(tSrS); ++i) {
-                if constexpr (!Is_causal) { // Just masking based on col
+                if constexpr (!Is_causal && !Is_local) { // Just masking based on col
                     if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                 } else { // mask based on both row and col
                     // using std::min is faster than doing col >= limit0 or col >= limit1
                     // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                     // right hand side can be negative and might be converted to a very large unsigned integer.
-                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
-                                                         col_limit_causal(int(get<0>(tScS(i))), n_block))) {
+                    if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block)) {
                         tSrS(i) = -INFINITY;
+                    } else if constexpr (Is_local) {
+                        if (int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block)) {
+                            tSrS(i) = -INFINITY;
+                        }
                     }
                 }
             }
         }
 
         softmax.template online_softmax(tSrS);
+
         Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout()));
         Tensor scores_scale = make_fragment_like(softmax.row_max);
         clear(scores_scale);
 
-        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
+        constexpr int n_masking_steps = (!Is_causal) ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
         // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
         #pragma unroll
-        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
+        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > n_block_min; ++masking_step, --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
             warp_scheduler_barrier_sync();
@@ -751,7 +792,7 @@ struct CollectiveMainloopFwd {
             Tensor tScS = threadMma0.partition_C(cS);
             #pragma unroll
             for (int i = 0; i < size(tSrS); ++i) {
-                if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
+                if (int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1)) {
                     tSrS(i) = -INFINITY;
                 }
             }
@@ -765,7 +806,7 @@ struct CollectiveMainloopFwd {
         }
 
         #pragma unroll 1
-        for (; n_block > 0; --n_block) {
+        for (; n_block > n_block_min; --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
             warp_scheduler_barrier_sync();
@@ -776,9 +817,24 @@ struct CollectiveMainloopFwd {
             warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k); // release K
+
+            if constexpr(Is_local) {
+                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
+                Tensor tScS = threadMma0.partition_C(cS);
+                #pragma unroll
+                for (int i = 0; i < size(tSrS); ++i) {
+                    if (
+                        int(get<1>(tScS(i))) >= col_limit_right(int(get<0>(tScS(i))), n_block - 1) ||
+                        int(get<1>(tScS(i))) < col_limit_left(int(get<0>(tScS(i))), n_block - 1)
+                    ) {
+                        tSrS(i) = -INFINITY;
+                    }
+                }
+            }
             // auto scores_scale = softmax.template max(tSrS);
-            cute::copy(softmax.template max(tSrS), scores_scale);
-            softmax.template online_softmax(tSrS);
+            cute::copy(softmax.template max(tSrS), scores_scale);
+            softmax.template online_softmax(tSrS);
+
             warpgroup_wait<0>();
             pipeline_v.consumer_release(smem_pipe_read_v); // release V
             ++smem_pipe_read_k;
@@ -791,11 +847,10 @@ struct CollectiveMainloopFwd {
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-        cute::copy(softmax.template finalize(tSrS), scores_scale);
+        cute::copy(softmax.template finalize(tSrS), scores_scale);
         warpgroup_wait<0>();
         pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang
         ++smem_pipe_read_v;
 
-        softmax.rescale_o(tOrO, scores_scale);
         return;
     }
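
In the consumer path, the two lambdas and the unrolled loops reduce to one per-element test. A plain Python restatement (illustrative only; argument names follow the kernel variables, and (row, col) are coordinates inside the current (m_block, n_block) tile):

    def is_masked(row, col, n_block, m_block, seqlen_q, seqlen_k,
                  kBlockM, kBlockN, window_size_left, window_size_right):
        # Same bounds as col_limit_right / col_limit_left above.
        col_limit_right = min(
            seqlen_k - n_block * kBlockN,
            row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + window_size_right,
        )
        col_limit_left = max(
            0,
            row + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - window_size_left,
        )
        return col >= col_limit_right or col < col_limit_left

A score at (row, col) survives exactly when col_limit_left <= col < col_limit_right; the causal-only branch applies only the right-hand limit.
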
diff --git a/hopper/setup.py b/hopper/setup.py
index 12e86667c..2a23d1657 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -144,7 +144,7 @@ def append_nvcc_threads(nvcc_extra_args):
                 "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",  # printing out number of registers
                 "-lineinfo",
                 "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
-                "-DNDEBUG",  # Important, otherwise performance is severely impacted
+                "-DNDEBUG",  # Important, otherwise performance is severely impacted
             ]
             include_dirs = [
                 # Path(this_dir) / "fmha-pipeline",
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index d66c65372..45ee8d4bf 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -24,11 +24,14 @@ def print_diffs(out, out_ref):
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [True])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -37,7 +40,7 @@ def print_diffs(out, out_ref):
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64, 128, 256])
 # @pytest.mark.parametrize("d", [64, 96, 128])
-# @pytest.mark.parametrize("d", [64, 128])
+# @pytest.mark.parametrize("d", [256])
 @pytest.mark.parametrize("d", [64, 128, 256])
 @pytest.mark.parametrize("descale", [1.0])
 # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
@@ -45,7 +48,7 @@ def print_diffs(out, out_ref):
     "seqlen_q,seqlen_k",
     [
         (1, 1),
-        (257, 1),
+        # (257, 1),
         (64, 128),
         (128, 128),
         (256, 256),
@@ -65,13 +68,13 @@ def print_diffs(out, out_ref):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype, descale
+    seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
 ):
     device = "cuda"
     if(dtype == torch.float8_e4m3fn):
         dtype_init = torch.float16
     else:
-        dtype_init = dtype
+        dtype_init = dtype
     print(dtype)
     # set seed
     torch.random.manual_seed(0)
@@ -80,9 +83,11 @@ def test_flash_attn_output(
     batch_size = 4
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-    # nheads_kv = 2
-    # batch_size = 9
-    # nheads = 6
+    # nheads_kv = 1
+    # batch_size = 1
+    # nheads = 1
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    print(f"window_size: {window_size}", flush=True)
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
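
On the dense path the test exercises the new window_size argument through a single call. A minimal usage sketch (the import path is an assumption based on this test file's conventions; shapes and dtypes are arbitrary):

    import torch
    from flash_attn_interface import flash_attn_func  # import path assumed, as in this test file

    batch, seqlen, nheads, d = 2, 512, 6, 128
    q = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)

    # window_size=(-1, -1) keeps full attention; (128, 0) lets each query see itself
    # and the previous 128 keys, i.e. causal-style sliding-window attention.
    out, lse = flash_attn_func(q, k, v, causal=False, window_size=(128, 0))
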
@@ -96,7 +101,7 @@ def test_flash_attn_output(
     descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
     descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
     descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
     if(dtype != torch.float8_e4m3fn):
-        out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic)
+        out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
     else:
         out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
             q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
@@ -113,7 +118,7 @@ def test_flash_attn_output(
         q = q * descale_q
         k = k * descale_k
         v = v * descale_v
-
+
     out_ref, attn_ref = attention_ref(
         q,
         k,
@@ -121,6 +126,7 @@ def test_flash_attn_output(
         None,
         None,
         causal=causal,
+        window_size=window_size,
     )
     out_pt, attn_pt = attention_ref(
         q,
@@ -129,6 +135,7 @@ def test_flash_attn_output(
         None,
         None,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
@@ -144,9 +151,9 @@ def test_flash_attn_output(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-
+
     # if not causal:
-    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()
 
     if d <= 128 and dtype != torch.float8_e4m3fn:
@@ -181,7 +188,7 @@ def test_flash_attn_output(
     # breakpoint()
     if(dtype != torch.float8_e4m3fn):
         assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
-    else:
+    else: # just test correctness of fp8 kernel w/o further quantization techniques
         assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
 
 
@@ -196,12 +203,16 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
-# @pytest.mark.parametrize("deterministic", [False])
+# @pytest.mark.parametrize("deterministic", [True])
+# @pytest.mark.parametrize("add_unused_qkv", [False, True])
+@pytest.mark.parametrize("add_unused_qkv", [True])
 # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [128])
+# @pytest.mark.parametrize('d', [256])
 # @pytest.mark.parametrize("d", [64, 128, 256])
 @pytest.mark.parametrize("d", [64, 128])
 # @pytest.mark.parametrize("d", [128])
@@ -231,7 +242,7 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype
+    seqlen_q, seqlen_k, d, causal, local, deterministic, add_unused_qkv, mha_type, dtype
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -241,12 +252,15 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    # batch_size = 1
-    # nheads = 1
     batch_size = 9
-    nheads = 6
-    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-
+    nheads = 4
+    nheads_kv = 4
+    # batch_size = 9
+    # nheads = 6
+    # nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
+
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
@@ -259,12 +273,27 @@ def test_flash_attn_varlen_output(
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
 
+    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
+        if add_unused:
+            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
+            attn_mask = torch.logical_and(padding_mask, another_mask)
+            unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)
+        else:
+            attn_mask = padding_mask
+            unused_mask = None
+        return attn_mask, unused_mask
+
+    query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device)
+    key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device)
+
     (
         q_unpad,
         k_unpad,
         v_unpad,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         q,
@@ -273,7 +302,7 @@ def test_flash_attn_varlen_output(
         output_pad_fn,
         dq_pad_fn,
         dk_pad_fn,
-    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
     # print("cu_seqlens_q: ", cu_seqlens_q)
     # print("cu_seqlens_k: ", cu_seqlens_k)
     # print("q_unpad, shape: ", q_unpad.shape)
@@ -289,8 +318,14 @@ def test_flash_attn_varlen_output(
         max_seqlen_k,
         causal=causal,
         deterministic=deterministic,
+        seqused_q=seqused_q,
+        seqused_k=seqused_k,
+        window_size=window_size,
     )
     out = output_pad_fn(out_unpad)
+    if query_unused_mask is not None:
+        q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
+        out.masked_fill_(q_zero_masking, 0.0)
     dropout_mask = None
 
     out_ref, attn_ref = attention_ref(
@@ -300,6 +335,7 @@ def test_flash_attn_varlen_output(
         query_padding_mask,
         key_padding_mask,
         causal=causal,
+        window_size=window_size,
     )
     out_pt, attn_pt = attention_ref(
         q,
@@ -308,6 +344,7 @@ def test_flash_attn_varlen_output(
         query_padding_mask,
         key_padding_mask,
         causal=causal,
+        window_size=window_size,
         upcast=False,
         reorder_ops=True,
     )
@@ -326,6 +363,10 @@ def test_flash_attn_varlen_output(
     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
     dk = dk_pad_fn(dk_unpad)
     dv = dk_pad_fn(dv_unpad)
+    if key_unused_mask is not None:
+        k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
+        dk.masked_fill_(k_zero_masking, 0.0)
+        dv.masked_fill_(k_zero_masking, 0.0)
     (
         dq_ref,
         dk_ref,
@@ -342,6 +383,7 @@ def test_flash_attn_varlen_output(
         dk_pt.masked_fill_(zero_masking, 0.0)
         dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
+    dq.masked_fill_(q_zero_masking, 0.0)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
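
The varlen test separates tokens that are actually attended from tokens that are only allocated in the unpadded buffer. A toy illustration of the same mask split (values invented for the example); the per-sequence count of attended tokens is what the test forwards as seqused_q / seqused_k:

    import torch

    padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)  # (batch=1, seqlen=6)
    another_mask = torch.tensor([[1, 0, 1, 0, 1, 1]], dtype=torch.bool)

    # Same construction as _gen_unused_masks above.
    attn_mask = torch.logical_and(padding_mask, another_mask)
    unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)

    print(attn_mask.int())    # tensor([[1, 0, 1, 0, 0, 0]])  tokens actually attended
    print(unused_mask.int())  # tensor([[0, 1, 0, 1, 1, 1]])  allocated but skipped
    print(attn_mask.sum(-1))  # tensor([2])  attended-token count per sequence
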
diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp
index ac93ca94b..2fbb417e4 100644
--- a/hopper/tile_scheduler.hpp
+++ b/hopper/tile_scheduler.hpp
@@ -270,4 +270,4 @@ class DynamicPersistentTileScheduler {
 
 };
 
-} // flash
\ No newline at end of file
+} // flash
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 72d55134e..d5bb6ba85 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -89,7 +89,7 @@ def generate_qkv(
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -104,8 +104,8 @@ def generate_qkv(
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
diff --git a/tests/test_rotary.py b/tests/test_rotary.py
index 6f2a5fae7..f6b3e5aed 100644
--- a/tests/test_rotary.py
+++ b/tests/test_rotary.py
@@ -215,7 +215,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     x_pt = x.detach().clone().requires_grad_()
     lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
     padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
-    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
+    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)
     x_unpad_clone = x_unpad.clone()
     x_unpad = x_unpad.requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
diff --git a/tests/test_util.py b/tests/test_util.py
index ebd7183f1..0802ca2e8 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -29,7 +29,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
 
 
 def generate_qkv(
-    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+    q, k, v, query_padding_mask=None, key_padding_mask=None,
+    kvpacked=False, qkvpacked=False, add_unused_qkv=False,
+    query_unused_mask=None, key_unused_mask=None,
 ):
     """
     Arguments:
@@ -44,9 +46,14 @@ def generate_qkv(
     _, seqlen_k, nheads_k, _ = k.shape
     assert k.shape == (batch_size, seqlen_k, nheads_k, d)
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+    if query_unused_mask is not None or key_unused_mask is not None:
+        assert not kvpacked
+        assert not qkvpacked
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
+            q, query_padding_mask, query_unused_mask,
+        )
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -55,20 +62,22 @@ def generate_qkv(
         cu_seqlens_q = torch.arange(
             0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
         )
+        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
         cu_seqlens_k = torch.arange(
             0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
         )
+        seqused_k = None
         max_seqlen_k = seqlen_k
 
     if qkvpacked:
@@ -125,6 +134,8 @@ def generate_qkv(
         v_unpad.detach().requires_grad_(),
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         q.detach().requires_grad_(),