Skip to content

Commit

Permalink
[Core] Flashinfer - Remove advance step size restriction (#10282)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavanimajety authored Nov 13, 2024
1 parent 1b886aa commit b6dde33
Showing 1 changed file with 38 additions and 28 deletions.
66 changes: 38 additions & 28 deletions csrc/prepare_inputs/advance_step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t,
}
}

/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
Expand Down Expand Up @@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;

// Update paged_kv_indptr
if (idx == 0) {
paged_kv_indptr_ptr[idx] = 0;
}
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
Expand All @@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}

__global__ void advance_step_flashinfer_indices_kernel(
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
int row = idx / block_tables_stride;
int col = idx % block_tables_stride;

if (row < num_queries && col < block_table_bound_ptr[row]) {
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
block_tables_ptr[row * block_tables_stride + col];
// note: max_num_blocks_per_seq = block_tables.stride(0)
int tid = blockIdx.x * blockDim.x + threadIdx.x;

// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if (num_queries < tid && tid <= num_seqs) {
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if (num_queries < row && row <= num_seqs) {
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];

// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
idx += (gridDim.x * blockDim.x)) {
// block_tables-row = paged_kv_indptr[queryNum]
int queryNum = idx / max_num_blocks_per_seq;
int col = idx % max_num_blocks_per_seq;
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
paged_kv_indices_ptr[indices_arr_idx] =
block_tables_ptr[block_tables_idx];
}
}
}

Expand Down Expand Up @@ -247,22 +263,16 @@ void advance_step_flashinfer(
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}

// TODO(will): support arbitrary block_tables stride
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
TORCH_CHECK(false,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,",
" increasing the block size or take smaller steps.",
" num_queries = ", num_queries,
" block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
int block_tables_stride = block_tables.stride(0);
TORCH_CHECK((blocks * threads > num_queries),
"multi-step: not enough threads to map to num_queries = ",
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
if (logging) {
printf("launching kernels with %d blocks and %d threads\n", blocks,
threads);
}

advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
Expand All @@ -281,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast<int*>(block_table_bound.data_ptr()));

advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
Expand Down

0 comments on commit b6dde33

Please sign in to comment.