[Core] Flashinfer - Remove advance step size restriction #10282

Merged (3 commits) on Nov 13, 2024
Changes from all commits
66 changes: 38 additions & 28 deletions csrc/prepare_inputs/advance_step.cu
@@ -88,6 +88,7 @@ inline void verify_tensor(std::string const& name, torch::Tensor const& t,
}
}

/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
@@ -134,8 +135,10 @@ __global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;

// Update paged_kv_indptr
if (idx == 0) {
paged_kv_indptr_ptr[idx] = 0;
}
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
@@ -146,20 +149,33 @@ __global__ void advance_step_flashinfer_indptr_kernel(
}

__global__ void advance_step_flashinfer_indices_kernel(
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
int row = idx / block_tables_stride;
int col = idx % block_tables_stride;

if (row < num_queries && col < block_table_bound_ptr[row]) {
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
block_tables_ptr[row * block_tables_stride + col];
// note: max_num_blocks_per_seq = block_tables.stride(0)
int tid = blockIdx.x * blockDim.x + threadIdx.x;

// when cuda graphs are enabled, paged_kv_indptr tensor
// has to be updated for the padded queries
// tid represents a query# for paged_kv_indptr tensor
if (num_queries < tid && tid <= num_seqs) {
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if (num_queries < row && row <= num_seqs) {
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];

// each thread processes a block_ptr in block_tables
// block_tables shape: [num_queries, max_num_blocks_per_seq]
// paged_kv_indices is flattened block_tables.
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
idx += (gridDim.x * blockDim.x)) {
// block_tables-row = paged_kv_indptr[queryNum]
int queryNum = idx / max_num_blocks_per_seq;
int col = idx % max_num_blocks_per_seq;
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
paged_kv_indices_ptr[indices_arr_idx] =
block_tables_ptr[block_tables_idx];
}
}
}
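For reference, the rewritten indices kernel replaces the old one-thread-per-table-entry mapping with a grid-stride loop over the flattened block_tables. The standalone host-side C++ sketch below reruns the same index math serially; the sizes and table values are made up for illustration and are not part of the PR.

#include <cstdint>
#include <cstdio>
#include <vector>

// Serial reference of the per-thread mapping in
// advance_step_flashinfer_indices_kernel: flatten block_tables
// [num_queries, max_num_blocks_per_seq], keep the first
// block_table_bound[q] entries of each row, and pack them at
// offset paged_kv_indptr[q] inside paged_kv_indices.
int main() {
  // Hypothetical example values, not taken from the PR.
  int num_queries = 2;
  int64_t max_num_blocks_per_seq = 4;
  std::vector<int> block_tables = {10, 11, 12, -1,   // query 0 uses 3 blocks
                                   20, 21, -1, -1};  // query 1 uses 2 blocks
  std::vector<int> block_table_bound = {3, 2};
  std::vector<int> paged_kv_indptr = {0, 3, 5};  // prefix sums of the bounds
  std::vector<int> paged_kv_indices(paged_kv_indptr.back());

  // Same body the kernel runs per thread, here executed for every flat idx.
  for (int idx = 0; idx < num_queries * max_num_blocks_per_seq; ++idx) {
    int queryNum = idx / max_num_blocks_per_seq;
    int col = idx % max_num_blocks_per_seq;
    if (queryNum < num_queries && col < block_table_bound[queryNum]) {
      paged_kv_indices[paged_kv_indptr[queryNum] + col] =
          block_tables[queryNum * max_num_blocks_per_seq + col];
    }
  }

  for (int v : paged_kv_indices) printf("%d ", v);  // prints: 10 11 12 20 21
  printf("\n");
  return 0;
}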

@@ -247,22 +263,16 @@ void advance_step_flashinfer(
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}

// TODO(will): support arbitrary block_tables stride
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
TORCH_CHECK(false,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,",
" increasing the block size or take smaller steps.",
" num_queries = ", num_queries,
" block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
int block_tables_stride = block_tables.stride(0);
TORCH_CHECK((blocks * threads > num_queries),
"multi-step: not enough threads to map to num_queries = ",
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
if (logging) {
printf("launching kernels with %d blocks and %d threads\n", blocks,
threads);
}

advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
@@ -281,7 +291,7 @@ void advance_step_flashinfer(
reinterpret_cast<int*>(block_table_bound.data_ptr()));

advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
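Why the stride-based restriction could be dropped: with the grid-stride loop each thread covers multiple entries of the flattened table, so the launch only has to satisfy the check the PR keeps, blocks * threads > num_queries, rather than supplying one thread per entry of num_queries * block_tables.stride(0). A small standalone C++ illustration of the two checks, using assumed device attributes (A100-like numbers, not taken from the PR):

#include <cstdio>

int main() {
  // Assumed device attributes and shapes, for illustration only.
  long blocks = 108;                // cudaDevAttrMultiProcessorCount (assumed)
  long threads = 1024;              // cudaDevAttrMaxThreadsPerBlock (assumed)
  long block_tables_stride = 2048;  // max_num_blocks_per_seq (assumed)
  long num_queries = 256;

  // Old requirement: one thread per block_tables entry of the first
  // num_queries rows, i.e. (blocks * threads) / stride >= num_queries.
  bool old_restriction_hit =
      (blocks * threads) / block_tables_stride < num_queries;  // 54 < 256

  // New requirement: enough threads to touch every (padded) query once;
  // the grid-stride loop then sweeps the rest of the flattened table.
  bool new_check_passes = blocks * threads > num_queries;  // 110592 > 256

  printf("old restriction violated: %s\n", old_restriction_hit ? "yes" : "no");
  printf("new launch check passes:  %s\n", new_check_passes ? "yes" : "no");
  return 0;
}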