diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 46fef79f439fb..bd184ee22682e 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t, } } +/// each thread processes a block per query __global__ void advance_step_flashinfer_kernel( int num_threads, int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, long const* sampled_token_ids_ptr, @@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel( int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { int idx = blockIdx.x * num_threads + threadIdx.x; - // Update paged_kv_indptr + if (idx == 0) { + paged_kv_indptr_ptr[idx] = 0; + } if (idx < num_queries) { int sum = 0; for (int i = 0; i <= idx; ++i) { @@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel( } __global__ void advance_step_flashinfer_indices_kernel( - int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr, - int64_t const block_tables_stride, int* paged_kv_indices_ptr, + int num_seqs, int num_queries, int const* block_tables_ptr, + int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { - int idx = blockIdx.x * num_threads + threadIdx.x; - int row = idx / block_tables_stride; - int col = idx % block_tables_stride; - - if (row < num_queries && col < block_table_bound_ptr[row]) { - paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] = - block_tables_ptr[row * block_tables_stride + col]; + // note: max_num_blocks_per_seq = block_tables.stride(0) + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // when cuda graphs are enabled, paged_kv_indptr tensor + // has to be updated for the padded queries + // tid represents a query# for paged_kv_indptr tensor + if (num_queries < tid && tid <= num_seqs) { + paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; } - // if cudagraph, fill padded seqs with the last valid seq's indptr - if (num_queries < row && row <= num_seqs) { - paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries]; + + // each thread processes a block_ptr in block_tables + // block_tables shape: [num_queries, max_num_blocks_per_seq] + // paged_kv_indices is flattened block_tables. + for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq); + idx += (gridDim.x * blockDim.x)) { + // block_tables-row = paged_kv_indptr[queryNum] + int queryNum = idx / max_num_blocks_per_seq; + int col = idx % max_num_blocks_per_seq; + if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) { + int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col; + int block_tables_idx = queryNum * max_num_blocks_per_seq + col; + paged_kv_indices_ptr[indices_arr_idx] = + block_tables_ptr[block_tables_idx]; + } } } @@ -247,22 +263,16 @@ void advance_step_flashinfer( int threads; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); - if (logging) { - printf("launching kernel with %d blocks\n", blocks); - } - // TODO(will): support arbitrary block_tables stride - if ((blocks * threads) / block_tables.stride(0) < num_queries) { - TORCH_CHECK(false, - "multi-step: not enough threads to map block_table to" - "FlashInfer's paged_kv_indices on GPU. Try reducing the number " - "of seqs,", - " increasing the block size or take smaller steps.", - " num_queries = ", num_queries, - " block_tables.stride(0) = ", block_tables.stride(0), - " blocks = ", blocks, " max_threads = ", threads); + int block_tables_stride = block_tables.stride(0); + TORCH_CHECK((blocks * threads > num_queries), + "multi-step: not enough threads to map to num_queries = ", + num_queries, " block_tables.stride(0) = ", block_tables.stride(0), + " blocks = ", blocks, " max_threads = ", threads); + if (logging) { + printf("launching kernels with %d blocks and %d threads\n", blocks, + threads); } - advance_step_flashinfer_kernel<<>>( threads, num_seqs, num_queries, block_size, reinterpret_cast(input_tokens.data_ptr()), @@ -281,7 +291,7 @@ void advance_step_flashinfer( reinterpret_cast(block_table_bound.data_ptr())); advance_step_flashinfer_indices_kernel<<>>( - threads, num_seqs, num_queries, + num_seqs, num_queries, reinterpret_cast(block_tables.data_ptr()), block_tables.stride(0), reinterpret_cast(paged_kv_indices.data_ptr()),