diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index b15b47ca..fe914d49 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -419,21 +419,41 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in return cudaSuccess; } +inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) { + if (avg_packed_qo_len > 64 && head_dim < 256) { + return 128; + } else { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + // avg_packed_qo_len <= 64 + return 64; + } else { + // avg_packed_qo_len <= 16 + return 16; + } + } else { + // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout + return 64; + } + } +} + template -inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, uint32_t max_batch_size_if_split, - bool enable_cuda_graph) { +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t total_num_rows, uint32_t max_seq_len, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, + uint32_t max_batch_size_if_split, bool enable_cuda_graph) { std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; merge_indptr.push_back(0); o_indptr.push_back(0); const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - uint32_t total_num_rows = qo_indptr_h[batch_size]; - // step 1: compute qo_chunk_size + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); - int64_t sum_packed_qo_len = 0; for (uint32_t i = 0; i < batch_size; ++i) { packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); if (packed_qo_len_arr[i] < 0) { @@ -449,41 +469,43 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin << kv_indptr_h[i] << " should be non-negative"; FLASHINFER_ERROR(err_msg.str()); } - sum_packed_qo_len += packed_qo_len_arr[i]; } - int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q uint32_t cta_tile_q; - if (avg_packed_qo_len > 64 && head_dim < 256) { - cta_tile_q = 128; + uint32_t total_num_tiles_q; + bool split_kv; + int64_t kv_chunk_size, new_batch_size; + if (enable_cuda_graph) { + // When CUDA graphs are enabled, the lengths of sequences determined by + // qo_indptr_h can vary. We assume that the dummy data based on which + // the CUDA graph is created fixes the maximum number of tokens. + uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; + cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim); + + // Find an upper bound for the number of tiles, derived from the total + // number of rows and the batch size. The sum of qo lengths rounded + // up to cta_tile_q will not exceed this number derived from the total + // number of rows. + total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size; + + split_kv = true; + kv_chunk_size = max_batch_size_if_split; + new_batch_size = max_batch_size_if_split; } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (avg_packed_qo_len > 16) { - // avg_packed_qo_len <= 64 - cta_tile_q = 64; - } else { - // avg_packed_qo_len <= 16 - cta_tile_q = 16; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - cta_tile_q = 64; + total_num_tiles_q = 0; + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); + sum_packed_qo_len += packed_qo_len_arr[i]; } - } - uint32_t total_num_tiles_q = 0; - for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q); - } + const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim); - // step 2: determine kv_chunk_size - auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize( - max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, - /*min_kv_chunk_size=*/std::max((128 / page_size), 1U)); - - if (enable_cuda_graph) { - split_kv = total_num_tiles_q < max_batch_size_if_split; + std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize( + max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, + /*min_kv_chunk_size=*/std::max((128 / page_size), 1U)); } // step 3: split qo_indptr and kv_indptr @@ -511,7 +533,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin kv_chunk_size *= page_size; return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, - total_num_rows, std::move(request_indices), std::move(qo_tile_indices), + std::move(request_indices), std::move(qo_tile_indices), std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); } @@ -597,9 +619,10 @@ template inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, - IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, + IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, + uint32_t max_seq_len, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; @@ -618,17 +641,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size - auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows, - request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, - o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, max_batch_size_if_split, enable_cuda_graph); + auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, + qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = + PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size, + max_batch_size_if_split, enable_cuda_graph); plan_info.cta_tile_q = cta_tile_q; plan_info.total_num_rows = total_num_rows; plan_info.enable_cuda_graph = enable_cuda_graph; size_t padded_batch_size = enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; + plan_info.padded_batch_size = padded_batch_size; plan_info.split_kv = split_kv; @@ -679,6 +703,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr"); plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask"); + IdType* merge_indptr_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); bool* block_valid_mask_h = diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index cd466c72..c2490a7f 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -42,8 +42,9 @@ using namespace flashinfer; std::vector BatchPrefillWithKVCachePlan( unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) { + unsigned int total_num_rows, unsigned int max_seq_len, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -58,8 +59,8 @@ std::vector BatchPrefillWithKVCachePlan( float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), - kv_indptr.data_ptr(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + kv_indptr.data_ptr(), total_num_rows, max_seq_len, batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index c2c3c1eb..6f450393 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -99,8 +99,9 @@ void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at: std::vector BatchPrefillWithKVCachePlan( unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream); + unsigned total_num_rows, unsigned int max_seq_len, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph, int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheRun( unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index b00fb280..d86c75b6 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -776,7 +776,9 @@ def plan( self._pin_memory_int_workspace_buffer, qo_indptr_host, indptr_host, + batch_size, # total_num_rows batch_size, + 1, # max_seq_len num_qo_heads, num_kv_heads, page_size, diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py index e2fce665..c7bf3838 100644 --- a/python/flashinfer/jit/batch_prefill_templ.py +++ b/python/flashinfer/jit/batch_prefill_templ.py @@ -139,6 +139,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str: at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, + unsigned int total_num_rows, + unsigned int max_seq_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, @@ -156,8 +158,9 @@ def paged_prefill_inst_templ(mask_mode: str) -> str: float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{dtype_idx}}>(), - kv_indptr.data_ptr<{{dtype_idx}}>(), batch_size, num_qo_heads, num_kv_heads, {{head_dim}}, - page_size, enable_cuda_graph, sizeof({{dtype_o}}), stream); + kv_indptr.data_ptr<{{dtype_idx}}>(), total_num_rows, max_seq_len, + batch_size, num_qo_heads, num_kv_heads, {{head_dim}}, page_size, + enable_cuda_graph, sizeof({{dtype_o}}), stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); @@ -457,6 +460,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str: at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, + unsigned int total_num_rows, + unsigned int max_seq_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 469b3705..0da6468c 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -833,6 +833,8 @@ def __init__( self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf + self._max_total_num_rows = None + self._max_seq_len = None @property def is_cuda_graph_enabled(self) -> bool: @@ -993,7 +995,33 @@ def plan( bitorder="little", ) + # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors + qo_indptr_host = qo_indptr.to("cpu") + paged_kv_indptr_host = paged_kv_indptr.to("cpu") + + total_num_rows = qo_indptr_host[-1] + max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item() + if self.is_cuda_graph_enabled: + if self._max_total_num_rows is None: + self._max_total_num_rows = total_num_rows + elif total_num_rows > self._max_total_num_rows: + raise ValueError( + "The total number of rows in qo_indptr {} in cuda graph mode cannot " + "exceed the number of rows set during initialization {}.".format( + total_num_rows, self._max_total_num_rows + ) + ) + if self._max_seq_len is None: + self._max_seq_len = max_seq_len + elif max_seq_len > self._max_seq_len: + raise ValueError( + "The maximum sequence length in qo_indptr {} in cuda graph mode cannot " + "exceed the sequence length set during initialization {}.".format( + max_seq_len, self._max_seq_len + ) + ) + if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed during the lifecycle of the wrapper in " @@ -1049,10 +1077,6 @@ def plan( self.device, non_blocking=non_blocking ) - # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors - qo_indptr_host = qo_indptr.to("cpu") - paged_kv_indptr_host = paged_kv_indptr.to("cpu") - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type self._cached_module = get_batch_prefill_module( @@ -1073,6 +1097,8 @@ def plan( self._pin_memory_int_workspace_buffer, qo_indptr_host, paged_kv_indptr_host, + total_num_rows, + max_seq_len, batch_size, num_qo_heads, num_kv_heads, @@ -1463,6 +1489,7 @@ def __init__( self._kv_indptr_buf = kv_indptr_buf self._custom_mask_buf = custom_mask_buf self._qk_indptr_buf = qk_indptr_buf + self._max_total_num_rows = None @property def is_cuda_graph_enabled(self) -> bool: @@ -1610,7 +1637,33 @@ def plan( bitorder="little", ) + # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors + qo_indptr_host = qo_indptr.to("cpu") + paged_kv_indptr_host = paged_kv_indptr.to("cpu") + + total_num_rows = qo_indptr_host[-1] + max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item() + if self.is_cuda_graph_enabled: + if self._max_total_num_rows is None: + self._max_total_num_rows = total_num_rows + elif total_num_rows > self._max_total_num_rows: + raise ValueError( + "The total number of rows in qo_indptr {} in cuda graph mode cannot " + "exceed the number of rows set during initialization {}.".format( + total_num_rows, self._max_total_num_rows + ) + ) + if self._max_seq_len is None: + self._max_seq_len = max_seq_len + elif max_seq_len > self._max_seq_len: + raise ValueError( + "The maximum sequence length in qo_indptr {} in cuda graph mode cannot " + "exceed the sequence length set during initialization {}.".format( + max_seq_len, self._max_seq_len + ) + ) + if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " @@ -1638,10 +1691,6 @@ def plan( self._custom_mask_buf = packed_custom_mask.to(self.device) self._qk_indptr_buf = qk_indptr.to(self.device) - # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors - qo_indptr_host = qo_indptr.to("cpu") - kv_indptr_host = kv_indptr.to("cpu") - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type self._cached_module = get_batch_prefill_module( @@ -1662,6 +1711,8 @@ def plan( self._pin_memory_int_workspace_buffer, qo_indptr_host, kv_indptr_host, + total_num_rows, + max_seq_len, batch_size, num_qo_heads, num_kv_heads, diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 05f07f08..f17fd91f 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -144,11 +144,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - handler.Plan((void*)thrust::raw_pointer_cast(float_buffer.data()), - float_workspace_size_in_bytes, - (void*)thrust::raw_pointer_cast(int_buffer.data()), - int_workspace_size_in_bytes, qo_indptr_h.data(), kv_indptr_host.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.Plan( + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_host.data(), /*total_num_rows=*/batch_size, /*max_seq_len=*/1, + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( diff --git a/src/bench_batch_prefill.cu b/src/bench_batch_prefill.cu index 802bbb1f..4d3cb976 100644 --- a/src/bench_batch_prefill.cu +++ b/src/bench_batch_prefill.cu @@ -76,7 +76,8 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { handler.Plan( thrust::raw_pointer_cast(float_workspace.data()), float_workspace_size_in_bytes, thrust::raw_pointer_cast(int_workspace.data()), int_workspace_size_in_bytes, - qo_indptr_h.data(), kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + qo_indptr_h.data(), kv_indptr_h.data(), /*total_num_rows=*/batch_size * qo_len, + /*max_seq_len=*/qo_len, batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 94824e89..0f85b7dc 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -256,8 +256,10 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { cascade_handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size); + qo_indptr_h.data(), kv_indptr_unique_h.data(), + /*total_num_rows=*/batch_size * qo_append_length, + /*max_seq_len=*/qo_append_length, batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = SinglePrefillWithKVCache( @@ -317,8 +319,10 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { baseline_handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size); + qo_indptr_h.data(), kv_indptr_combined_h.data(), + /*total_num_rows=*/batch_size * qo_append_length, + /*max_seq_len=*/qo_append_length, batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 411165a5..4fb7ac6a 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -183,14 +183,16 @@ class BatchPrefillHandler { template cudaError_t Plan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, size_t int_workspace_size_in_bytes, IdType* qo_indptr_h, IdType* kv_indptr_h, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, uint32_t page_size) { + uint32_t total_num_rows, uint32_t max_seq_len, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size) { int_buffer_ = int_buffer; float_buffer_ = float_buffer; return PrefillPlan(float_buffer, float_workspace_size_in_bytes, int_buffer, page_locked_buffer_, int_workspace_size_in_bytes, plan_info_, - qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, enable_cuda_graph_, sizeof(DTypeO), stream_); + qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size, enable_cuda_graph_, + sizeof(DTypeO), stream_); } cudaStream_t GetCUDAStream() const { return stream_; } diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index 6302a065..babc1d07 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -108,11 +108,11 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n thrust::device_vector q_device(q); thrust::device_vector o_device(q_len * num_qo_heads * head_dim); - handler.Plan((void*)thrust::raw_pointer_cast(float_buffer.data()), - float_workspace_size_in_bytes, - (void*)thrust::raw_pointer_cast(int_buffer.data()), - int_workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.Plan( + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + q_indptr.data(), kv_indptr.data(), /*total_num_rows=*/q_indptr.back(), + /*max_seq_len=*/q_len, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { auto status = @@ -154,7 +154,8 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo bool allow_fp16_qk_reduction) { uint32_t batch_size = 9; std::vector q_lens(batch_size), kv_lens(batch_size); - utils::vec_randint_(q_lens, 10, 15); + const uint32_t max_seq_len = 15; + utils::vec_randint_(q_lens, 10, max_seq_len); utils::vec_randint_(kv_lens, 128, 2048); std::vector append_indptr{0}, kv_indptr{0}; @@ -202,8 +203,8 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - append_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, - /*page_size=*/1); + append_indptr.data(), kv_indptr.data(), /*total_num_rows=*/append_indptr.back(), max_seq_len, + batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1); auto status = BatchPrefillWithRaggedKVCacheWrapper( &handler, thrust::raw_pointer_cast(queries_device.data()), @@ -245,9 +246,10 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si bool causal, PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { - uint32_t batch_size = 7; + const uint32_t batch_size = 7; + const uint32_t max_seq_len = 64; std::vector q_lens(batch_size); - utils::vec_randint_(q_lens, 1, 64); + utils::vec_randint_(q_lens, 1, max_seq_len); std::vector kv_lens(q_lens); std::vector q_indptr{0}; @@ -334,11 +336,11 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - handler.Plan((void*)thrust::raw_pointer_cast(float_buffer.data()), - float_workspace_size_in_bytes, - (void*)thrust::raw_pointer_cast(int_buffer.data()), - int_workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.Plan( + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + q_indptr.data(), kv_indptr.data(), /*total_num_rows=*/q_indptr.back(), max_seq_len, + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), @@ -463,11 +465,11 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - handler.Plan((void*)thrust::raw_pointer_cast(float_buffer.data()), - float_workspace_size_in_bytes, - (void*)thrust::raw_pointer_cast(int_buffer.data()), - int_workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + handler.Plan( + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + q_indptr.data(), kv_indptr.data(), /*total_num_rows=*/q_indptr.back(), q_len_max, batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size); auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), @@ -568,7 +570,8 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - append_indptr.data(), kv_indptr.data(), + append_indptr.data(), kv_indptr.data(), /*total_num_rows=*/append_indptr.back(), + /*max_seq_len=*/append_indptr.back(), /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size); auto status = BatchPrefillWithPagedKVCacheWrapper( diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 0cd83fa8..685f98c5 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -494,13 +494,15 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, baseline_handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size); + qo_indptr_h.data(), kv_indptr_combined_h.data(), /*total_num_rows=*/qo_indptr_h.back(), + /*max_seq_len=*/qo_append_length, batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); cascade_handler.Plan( (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, - qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size); + qo_indptr_h.data(), kv_indptr_unique_h.data(), /*total_num_rows=*/qo_indptr_h.back(), + /*max_seq_len=*/qo_append_length, batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()),