fix: fatal bugfix in batch decode operator (#177)
The `BatchDecodeWithPagedKVCacheWrapper` never actually dispatched into the kernel: the dispatch macros in the wrapper body were commented out, so the call returned `cudaSuccess` without launching any decode work.
yzh119 authored Mar 12, 2024
1 parent 44d3c03 commit 238563f
Showing 3 changed files with 42 additions and 54 deletions.
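As context for the fix, here is a minimal, self-contained sketch of the failure mode the diff below removes. This is not FlashInfer code: BuggyWrapper, FixedWrapper, BatchDecodeDispatched, and g_kernel_launched are stand-in names invented for illustration, and cudaError_t/cudaSuccess are emulated so the snippet compiles without the CUDA toolkit. The point is only that a wrapper whose dispatch call is commented out still reports success while never reaching the kernel.

#include <iostream>

using cudaError_t = int;                // stand-in for the CUDA runtime type
constexpr cudaError_t cudaSuccess = 0;  // stand-in for the CUDA success status

bool g_kernel_launched = false;  // records whether the "kernel" was reached

// Stand-in for the dispatched implementation that actually launches the kernel.
cudaError_t BatchDecodeDispatched() {
  g_kernel_launched = true;
  return cudaSuccess;
}

// Mirrors the pre-fix wrapper: the dispatch call is commented out, so the
// function reports success without ever reaching the kernel.
cudaError_t BuggyWrapper() {
  // return BatchDecodeDispatched();
  return cudaSuccess;
}

// Mirrors the post-fix wrapper: it forwards into the dispatched implementation.
cudaError_t FixedWrapper() { return BatchDecodeDispatched(); }

int main() {
  BuggyWrapper();
  std::cout << "buggy wrapper reached kernel: " << g_kernel_launched << "\n";  // prints 0
  g_kernel_launched = false;
  FixedWrapper();
  std::cout << "fixed wrapper reached kernel: " << g_kernel_launched << "\n";  // prints 1
}

Before this commit, `BatchDecodeWithPagedKVCacheWrapper` behaved like BuggyWrapper above: its DISPATCH_* block was commented out, so callers saw a success status even though no decode kernel was launched.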
90 changes: 39 additions & 51 deletions include/flashinfer/decode_attention_decl.cuh
@@ -166,51 +166,15 @@ cudaError_t BatchDecodeWithPagedKVCache(
* \note This wrapper function should be only called after we call BeginForward function in the
* BatchDecodeHandler.
*/
- template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
-           PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut, typename IdType>
- cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
-     BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
-     float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
-   paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
-   kv_partition_info_t<IdType> kv_partition_info;
-   DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
-
-   if (handler->IsForwardStarted()) {
-     if (tmp != nullptr) {
-       // create auxiliary information for cooperative kernels
-       new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
-       new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
-       new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-       kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
-       kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
-       kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
-       kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
-       kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
-     }
-   } else {
-     std::ostringstream err_msg;
-     err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
-                "BatchDecodeWithPagedKVCacheWrapper()";
-     throw std::runtime_error(err_msg.str());
-   }
-
-   return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, kv_layout,
-                                                pos_encoding_mode, DTypeIn, DTypeOut, IdType>(
-       q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta,
-       stream);
-   return cudaSuccess;
- }
-
- template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+ template <PageStorage page_storage, QKVLayout KV_LAYOUT, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapper(
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+     paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
-   const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
+   float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
const uint32_t num_kv_heads = paged_kv.num_heads;
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
@@ -219,18 +183,42 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
throw std::invalid_argument(err_msg.str());
}

-   // DISPATCH_GQA_GROUP_SIZE(
-   //     num_qo_heads / num_kv_heads, GROUP_SIZE,
-   //     {DISPATCH_HEAD_DIM(
-   //         paged_kv.head_dim, HEAD_DIM,
-   //         {DISPATCH_POS_ENCODING_MODE(
-   //             pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
-   //               return BatchDecodeWithPagedKVCacheWrapperDispatched<
-   //                   page_storage, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, DTypeIn,
-   //                   DTypeOut, IdType>(handler, q, q_offset, paged_kv, o, lse, sm_scale,
-   //                   rope_scale,
-   //                   rope_theta, stream);
-   //             })})})});
+   DISPATCH_GQA_GROUP_SIZE(
+       num_qo_heads / num_kv_heads, GROUP_SIZE,
+       {DISPATCH_HEAD_DIM(
+           paged_kv.head_dim, HEAD_DIM,
+           {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, {
+             paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
+             kv_partition_info_t<IdType> kv_partition_info;
+             DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
+
+             if (handler->IsForwardStarted()) {
+               if (tmp != nullptr) {
+                 // create auxiliary information for cooperative kernels
+                 new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
+                 new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
+                 new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
+                 kv_partition_info.batch_size_before_partition =
+                     handler->GetBatchSizeBeforePartition();
+                 kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+                 kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+                 kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+                 kv_partition_info.seq_lens_before_partition =
+                     handler->GetSeqLengthsBeforePartition<IdType>();
+               }
+             } else {
+               std::ostringstream err_msg;
+               err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
+                          "BatchDecodeWithPagedKVCacheWrapper()";
+               throw std::runtime_error(err_msg.str());
+             }
+
+             return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                          KV_LAYOUT, POS_ENCODING_MODE, DTypeIn,
+                                                          DTypeOut, IdType>(
+                 q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale,
+                 rope_theta, stream);
+           })})});
return cudaSuccess;
}
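To make the dispatch arguments above concrete, here is a small self-contained sketch (again not FlashInfer code; ComputeGroupSize is a hypothetical helper written for illustration) of the two pieces of host-side logic visible in the new wrapper body: the num_qo_heads % num_kv_heads guard that determines GROUP_SIZE, and the default softmax scale used when maybe_sm_scale is not supplied. The head counts and head_dim come from the test parametrizations in this commit.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <optional>
#include <sstream>
#include <stdexcept>

// GQA group size: query heads must be an integer multiple of key/value heads;
// the ratio selects the GROUP_SIZE template instantiation in the dispatch above.
uint32_t ComputeGroupSize(uint32_t num_qo_heads, uint32_t num_kv_heads) {
  if (num_qo_heads % num_kv_heads != 0) {
    std::ostringstream err_msg;
    err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads "
            << num_kv_heads;
    throw std::invalid_argument(err_msg.str());
  }
  return num_qo_heads / num_kv_heads;
}

int main() {
  std::cout << ComputeGroupSize(4, 4) << "\n";   // 1: plain multi-head attention
  std::cout << ComputeGroupSize(32, 4) << "\n";  // 8: grouped-query attention

  // Default softmax scale, computed the same way as in the wrapper when no
  // explicit value is passed (head_dim = 128 as in the tests below).
  const uint32_t head_dim = 128;
  std::optional<float> maybe_sm_scale = std::nullopt;
  const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
  std::cout << sm_scale << "\n";  // ~0.0884
}

Note that the wrapper also requires BatchDecodeHandler's BeginForward() to have been called first; otherwise it throws the std::runtime_error shown in the diff.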

2 changes: 1 addition & 1 deletion python/tests/test_batch_prefill_kernels.py
@@ -24,7 +24,7 @@
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("page_size", [1, 8, 16])
@pytest.mark.parametrize("page_size", [1, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
4 changes: 2 additions & 2 deletions python/tests/test_shared_prefix_kernels.py
@@ -58,7 +58,7 @@ def test_batch_decode_with_shared_prefix_padded_kv_cache(
@pytest.mark.parametrize("shared_kv_len", [54, 97, 1979])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 4, 16])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_decode_with_shared_prefix_paged_kv_cache(
batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size
):
@@ -131,7 +131,7 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache(
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 4, 16])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_prefill_with_shared_prefix_paged_kv_cache(
batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size
):
