Commit: hopper local attention

ipiszy committed Sep 3, 2024
1 parent 65e6a2c commit 223b148
Showing 9 changed files with 252 additions and 106 deletions.
1 change: 1 addition & 0 deletions hopper/flash.h
@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
bool is_bf16;
bool is_e4m3;
bool is_causal;
bool is_local;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
41 changes: 23 additions & 18 deletions hopper/flash_api.cpp
@@ -131,7 +131,11 @@ void set_params_fprop(Flash_fwd_params &params,
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
params.is_causal = window_size_left < 0 && window_size_right == 0;

if ((window_size_left >= 0 || window_size_right >= 0) && !params.is_causal) {
params.is_local = true;
}
window_size_left = std::min(int(seqlen_k), window_size_left);
window_size_right = std::min(int(seqlen_k), window_size_right);
if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
params.window_size_left = window_size_left;
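Note on the window convention used in this hunk: window_size_left / window_size_right of -1 mean "unbounded" on that side, and the mask diagonal is aligned to the bottom-right corner of the score matrix, so causal is the special case left = -1, right = 0. The snippet below is a reference-only illustration of that masking rule, not code from this commit; the fused kernels never materialize such a mask.

import torch

def sliding_window_mask(seqlen_q, seqlen_k, window_size=(-1, -1), causal=False, device="cpu"):
    # Reference-only boolean mask: True means "query i may attend key j".
    left, right = window_size
    if causal:
        right = 0  # causal == local attention with no look-ahead
    q_idx = torch.arange(seqlen_q, device=device)[:, None]
    k_idx = torch.arange(seqlen_k, device=device)[None, :]
    shift = seqlen_k - seqlen_q  # bottom-right-aligned diagonal
    mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=device)
    if left >= 0:
        mask &= k_idx >= q_idx + shift - left
    if right >= 0:
        mask &= k_idx <= q_idx + shift + right
    return mask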
@@ -273,7 +277,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &descale_q_, // 1
c10::optional<at::Tensor> &descale_k_, // 1
c10::optional<at::Tensor> &descale_v_, // 1
bool is_causal) {
bool is_causal,
int window_size_left,
int window_size_right) {

auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -375,8 +381,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/-1,
/*window_size_right=*/is_causal ? 0 : -1);
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right);

auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
@@ -437,7 +443,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int max_seqlen_q,
const int max_seqlen_k,
const float softmax_scale,
bool is_causal) {
bool is_causal,
int window_size_left,
int window_size_right) {

auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -468,10 +476,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);

int window_size_left = -1;
int window_size_right = -1;
if (is_causal) { window_size_right = 0; }

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

const int total_q = q.sizes()[0];
@@ -480,9 +484,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

CHECK_SHAPE(q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
@@ -558,8 +559,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right,
is_causal ? -1 : window_size_left,
is_causal ? 0 : window_size_right,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/true);
params.total_q = total_q;
@@ -620,6 +621,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
const float softmax_scale,
const bool is_causal,
const int window_size_left,
const int window_size_right,
const bool deterministic) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
Expand Down Expand Up @@ -759,8 +762,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_d.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/-1,
/*window_size_right=*/is_causal ? 0 : -1,
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right,
deterministic);
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

@@ -811,6 +814,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
const int max_seqlen_k, // max sequence length to choose the kernel
const float softmax_scale,
const bool is_causal,
const int window_size_left,
const int window_size_right,
const bool deterministic) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
@@ -973,8 +978,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
softmax_d.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/-1,
/*window_size_right=*/is_causal ? 0 : -1,
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right,
deterministic);
params.total_q = total_q;
params.total_k = total_k;
32 changes: 28 additions & 4 deletions hopper/flash_attn_interface.py
@@ -14,7 +14,7 @@
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None):
def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
@@ -26,6 +26,8 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descal
descale_k,
descale_v,
causal,
window_size[0],
window_size[1],
)
return out, q, k, v, out_padded, softmax_lse, S_dmask

@@ -42,6 +44,7 @@ def _flash_attn_backward(
dv,
softmax_scale,
causal,
window_size,
deterministic=False
):
# dq, dk, dv are allocated by us so they should already be contiguous
@@ -58,6 +61,8 @@ def _flash_attn_backward(
dv,
softmax_scale,
causal,
window_size[0],
window_size[1],
deterministic,
)
return dq, dk, dv, softmax_d
@@ -72,6 +77,7 @@ def _flash_attn_varlen_forward(
max_seqlen_k,
softmax_scale,
causal,
window_size=(-1, -1),
seqused_q=None,
seqused_k=None,
):
@@ -90,6 +96,8 @@ def _flash_attn_varlen_forward(
max_seqlen_k,
softmax_scale,
causal,
window_size[0],
window_size[1],
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -112,6 +120,7 @@ def _flash_attn_varlen_backward(
max_seqlen_k,
softmax_scale,
causal,
window_size,
deterministic=False,
seqused_q=None,
seqused_k=None,
@@ -143,6 +152,8 @@ def _flash_attn_varlen_backward(
max_seqlen_k,
softmax_scale,
causal,
window_size[0],
window_size[1],
deterministic,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
@@ -159,6 +170,7 @@ def forward(
v,
softmax_scale,
causal,
window_size,
deterministic=False,
):
if softmax_scale is None:
@@ -168,11 +180,13 @@ def forward(
k,
v,
softmax_scale,
causal
causal,
window_size
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.deterministic = deterministic
return out, softmax_lse

@@ -192,6 +206,7 @@ def backward(ctx, dout, *args):
dv,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.deterministic,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
@@ -213,6 +228,7 @@ def forward(
max_seqlen_k,
softmax_scale,
causal,
window_size,
deterministic=False,
seqused_q=None,
seqused_k=None,
@@ -229,6 +245,7 @@ def forward(
max_seqlen_k,
softmax_scale,
causal=causal,
window_size=window_size,
seqused_q=seqused_q,
seqused_k=seqused_k,
)
@@ -240,6 +257,7 @@ def forward(
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.deterministic = deterministic
return out, softmax_lse

@@ -263,6 +281,7 @@ def backward(ctx, dout, *args):
ctx.max_seqlen_k,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.deterministic,
seqused_q,
seqused_k,
@@ -279,6 +298,7 @@ def flash_attn_func(
v,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False
):
"""dropout_p should be set to 0.0 during evaluation
@@ -335,6 +355,7 @@ def flash_attn_func(
v,
softmax_scale,
causal,
window_size,
deterministic,
)
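For illustration, a hypothetical call through the extended flash_attn_func wrapper; the shapes, head counts, dtype, and import path below are assumptions for the sketch rather than anything taken from this diff (the hopper interface at this commit returns a (out, softmax_lse) pair).

import torch
from flash_attn_interface import flash_attn_func  # import path may differ per build

q = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)

# Sliding-window local attention: each query attends at most 256 keys to its
# left and none to its right. window_size=(-1, -1) leaves behavior unchanged.
out, softmax_lse = flash_attn_func(q, k, v, causal=False, window_size=(256, 0))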

@@ -349,6 +370,7 @@ def flash_attn_varlen_func(
max_seqlen_k,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
seqused_q=None,
seqused_k=None,
@@ -382,9 +404,10 @@ def flash_attn_varlen_func(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
query and output tokens in each sequence.
seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
key and value tokens in each sequence.
Return:
out: (total, nheads, headdim).
@@ -402,6 +425,7 @@ def flash_attn_varlen_func(
max_seqlen_k,
softmax_scale,
causal,
window_size,
deterministic,
seqused_q,
seqused_k,
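A matching sketch for the varlen path, again hypothetical: the packed lengths and head sizes are made up, and the positional argument order (q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ...) is assumed to follow the interface as of this commit.

import torch
from flash_attn_interface import flash_attn_varlen_func  # import path may differ per build

# Two packed sequences (lengths 1000 and 3000), 16 heads, head dim 128.
cu_seqlens = torch.tensor([0, 1000, 4000], dtype=torch.int32, device="cuda")
q = torch.randn(4000, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(4000, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(4000, 16, 128, device="cuda", dtype=torch.bfloat16)

out, softmax_lse = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, 3000, 3000,
    causal=False,
    window_size=(512, 0),  # 512-token look-back per query, no look-ahead
)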
6 changes: 3 additions & 3 deletions hopper/flash_bwd_kernel.h
@@ -225,7 +225,7 @@ class FlashAttnBwd {
}
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) {
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
continue;
@@ -251,7 +251,7 @@ class FlashAttnBwd {
}
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) { continue; }
}
collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
@@ -281,7 +281,7 @@ class FlashAttnBwd {
}
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = cute::ceil_div(collective_mainloop.get_seqlen_q(params.mainloop, bidb), kBlockM);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV
collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
continue;
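Context for the three hunks above: with only causal masking, the last query block that can touch a given key block is simply ceil_div(seqlen_q, kBlockM), but with a finite left window the range of relevant query blocks is bounded on both sides, hence the new get_m_block_max. The following is a reference-only sketch of that block-range arithmetic, not the kernel's code; the exact rounding in CollectiveMainloopBwd::get_m_block_min / get_m_block_max may differ.

import math

def bwd_m_block_range(n_block, kBlockM, kBlockN, seqlen_q, seqlen_k,
                      window_left, window_right, is_causal, is_local):
    # Which query blocks can touch key block n_block under the
    # bottom-right-aligned causal/local mask. Reference sketch only.
    shift = seqlen_k - seqlen_q
    m_block_min = 0
    m_block_max = math.ceil(seqlen_q / kBlockM)
    right = 0 if is_causal else window_right
    if (is_causal or is_local) and right >= 0:
        # Earliest query row whose window reaches the start of this key block.
        m_block_min = max(0, (n_block * kBlockN - shift - right) // kBlockM)
    if is_local and window_left >= 0:
        # Latest query row whose window still covers the end of this key block.
        last_q = (n_block + 1) * kBlockN - 1 - shift + window_left
        m_block_max = min(m_block_max, math.ceil((last_q + 1) / kBlockM))
    return m_block_min, m_block_max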
28 changes: 17 additions & 11 deletions hopper/flash_bwd_launch_template.h
@@ -20,7 +20,7 @@

using namespace cute;

template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Varlen, bool Deterministic,
template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_causal, bool Is_local, bool Varlen, bool Deterministic,
bool dKV_swapAB, bool dQ_swapAB, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
@@ -57,7 +57,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
using ClusterShape = cute::Shape<_1, Int<1>, _1>;
static constexpr int Stages = 2;
using CollectiveMainloop = flash::CollectiveMainloopBwd<Stages, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
Is_causal, Varlen, Deterministic,
Is_causal, Is_local, Varlen, Deterministic,
dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>;
using CollectiveEpilogue = flash::CollectiveEpilogueBwd<TileShape_MNK, Element, CollectiveMainloop::NumMmaThreads, Varlen>;
using Scheduler = flash::SingleTileSchedulerBwd;
@@ -170,9 +170,11 @@ template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 128, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 128, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 2>(params, stream);
});
});
});
});
@@ -182,9 +184,11 @@ template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
});
});
});
});
@@ -194,9 +198,11 @@ template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
BOOL_SWITCH(params.is_local, Is_local, [&] {
BOOL_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
BOOL_SWITCH(params.deterministic, Deterministic, [&] {
run_flash_bwd<Headdim, 64, 128, T, Is_causal, Is_local, Varlen, Deterministic, false, false, 1, 2, 1>(params, stream);
});
});
});
});
(The remaining 4 changed files in this commit are not shown in this view.)
