Change causal mask to be aligned to bottom-right instead of top-left
tridao committed Aug 25, 2023
1 parent e07aa03 commit 9e5e8bc
Showing 11 changed files with 573 additions and 204 deletions.
26 changes: 26 additions & 0 deletions README.md
@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
## Changes in v2.1 (compared to v2.0)

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom-right corner of the attention matrix instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
v2.0:
    1 0 0 0 0
    1 1 0 0 0
v2.1:
    1 1 1 1 0
    1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
    1 0
    1 1
    1 1
    1 1
    1 1
v2.1:
    0 0
    0 0
    0 0
    1 0
    1 1
If a row of the mask is entirely zero, the corresponding row of the output will be zero.
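
For intuition, here is a minimal PyTorch sketch of the v2.1 semantics (the helper name `causal_mask_bottom_right` is ours for illustration; the actual masking happens inside the CUDA kernels):

```python
import torch

def causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # 1 = keep, 0 = masked out. Query row i may attend to key column j
    # iff j - i <= seqlen_k - seqlen_q, i.e. the causal diagonal is
    # anchored at the bottom-right corner of the attention matrix.
    i = torch.arange(seqlen_q).unsqueeze(1)  # (seqlen_q, 1)
    j = torch.arange(seqlen_k).unsqueeze(0)  # (1, seqlen_k)
    return (j - i <= seqlen_k - seqlen_q).int()

print(causal_mask_bottom_right(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```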

## Performance

24 changes: 10 additions & 14 deletions benchmarks/benchmark_causal.py
@@ -15,12 +15,7 @@

# from triton.ops.flash_attention import attention as attention_triton

try:
from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
except ImportError:
fav2_qkvpacked_func = None
fav2_kvpacked_func = None
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func

try:
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -80,8 +75,8 @@ def attention_megatron(qkv):

torch.manual_seed(0)
repeats = 30
batch_size = 2
seqlen = 8192
batch_size = 8
seqlen = 2048
nheads = 12
headdim = 128
# nheads = 24
@@ -90,8 +85,8 @@ def attention_megatron(qkv):
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.1
causal = False
dropout_p = 0.0
causal = True
dtype = torch.float16
device = 'cuda'

@@ -100,20 +95,20 @@ def attention_megatron(qkv):
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)

# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
# if fav2_qkvpacked_func is not None:
# benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)

# for dropout_p in [0.1, 0.0]:
# for causal in [False, True]:
# print(f"### {dropout_p = }, {causal = } ###")
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)


# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
@@ -151,6 +146,7 @@ def attention_megatron(qkv):
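# 4 = 2 matmuls (Q @ K^T and P @ V) x 2 FLOPs per multiply-accumulate;
# 312 is the A100's peak FP16/BF16 tensor-core throughput in TFLOPS.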
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
exit(0)


def time_fwd_bwd(func, *args, **kwargs):
4 changes: 2 additions & 2 deletions csrc/flash_attn/src/block_info.h
@@ -32,8 +32,8 @@ struct BlockInfo {

const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
const int actual_seqlen_q;
const int actual_seqlen_k;
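// Note (v2.1): these are plain ints rather than uint32_t, since the
// bottom-right-aligned causal mask takes differences such as
// actual_seqlen_q - actual_seqlen_k, which must be able to go negative
// without wrapping.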
};

////////////////////////////////////////////////////////////////////////////////////////////////////
71 changes: 19 additions & 52 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
m_block_min = m_block_min < 0 ? 0 : m_block_min;

// We might need to exit early and write 0 to dK and dV.
// Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
// TODO: what if we're not parallelizing, do we need to compute dot_do_o?
if (Is_causal && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}
int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
// We're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
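// For example, with kBlockM = kBlockN = 64, seqlen_q = 128, seqlen_k = 512, and n_block = 7
// (key columns 448..511), the first query row on or below the shifted diagonal is
// 448 + 128 - 512 = 64, so m_block_min = 64 / 64 = 1 and query block 0 is skipped.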

if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
tQsQ.data() = tQsQ.data() + size(sQ);
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
// Using uint32_t row makes it 10us slower on d=128, not sure why.
const int row = get<0>(taccScS_row(mi));
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
}
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), then for the 2nd block of seqlen_q (rows 128 to 255) we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
if (m_block * kBlockM < (n_block + 1) * kBlockN
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_q, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
uint32_t warp_id = tidx / 32;
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
int warp_id = tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
static_assert(MMA_N_SdP % 2 == 0);
uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
// Using uint32_t row makes it 10us slower on d=128, not sure why.
const int row = get<0>(taccScS_row(mi));
lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
}
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// the corresponding values of K would be 0, so the result would still be correct.
if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
AtomLayoutMS * 16);
}
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
uint32_t warp_id = tidx / 32;
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
int warp_id = tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
static_assert(MMA_N_SdP % 2 == 0);
uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,