feat: fix the maximal grid dimension in prefill planning with CUDA graphs (#639)

Previously, differences in the contents of qo_indptr could lead to block
sizes varying across CUDA graph invocations, resulting in illegal memory
accesses.
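
To illustrate the failure mode, here is a minimal sketch (plain Python, with made-up head_dim and GQA group size) of the previous average-based tile-size selection: two qo_indptr contents with the same batch size can land on different CTA tile sizes, so a graph captured with one shape replays with the wrong one.

# Minimal sketch of the pre-fix selection logic; not library code.
# gqa_group_size and head_dim are illustrative values.
def old_cta_tile_q(qo_indptr, gqa_group_size=4, head_dim=128):
    qo_lens = [b - a for a, b in zip(qo_indptr[:-1], qo_indptr[1:])]
    avg_packed_qo_len = sum(qo_lens) * gqa_group_size // len(qo_lens)
    if avg_packed_qo_len > 64 and head_dim < 256:
        return 128
    return 64 if avg_packed_qo_len > 16 else 16  # Ampere-or-newer branch

print(old_cta_tile_q([0, 32, 64, 96]))  # long sequences  -> 128
print(old_cta_tile_q([0, 1, 2, 3]))     # short sequences -> 16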

This PR alters the calculation of the block size to find a reasonable
maximum based on the longest sequence.

The maximum token count is fixed in `plan` on the `Python` side and
passed along to `scheduler.cuh` to derive the other parameters.
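
As a rough sketch of what `plan` now computes on the host before calling into the C++ planner (mirroring the `prefill.py` changes below; the indptr values are made up):

import torch

# Example qo_indptr on the host; values are illustrative.
qo_indptr_host = torch.tensor([0, 7, 7, 19, 32], dtype=torch.int32)

total_num_rows = int(qo_indptr_host[-1])                                 # 32
max_seq_len = int(torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]))   # 12

# Under CUDA graphs these act as upper bounds: later plan() calls may not
# exceed the values recorded at capture time, which the new checks enforce.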

While this ensures correctness under CUDA graphs, split-kv is now always
used when CUDA graphs are enabled, which may degrade performance if CUDA
graphs are used with a fixed `qo_indptr`. For varying `qo_indptr`, however,
CUDA graphs deliver a 4x performance improvement for prefill on models such
as Llama 3.2-1B.
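
As a sketch of the resulting launch-shape padding (mirroring `PrefillPlan` below; `max_grid_size` is just an assumed occupancy-derived number here):

# Illustrative numbers only; in the real planner max_grid_size comes from an
# occupancy query and max_batch_size_if_split = max_grid_size / num_kv_heads.
max_grid_size, num_kv_heads = 1024, 8
max_batch_size_if_split = max_grid_size // num_kv_heads   # 128

enable_cuda_graph = True
total_num_tiles_q = 48   # upper bound derived from total_num_rows and batch_size
new_batch_size = 48      # what the non-graph path would have used

padded_batch_size = (
    max(max_batch_size_if_split, total_num_tiles_q)
    if enable_cuda_graph
    else new_batch_size
)
print(padded_batch_size)  # 128, fixed across replays, so the captured grid never changes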
nandor authored Nov 25, 2024
1 parent 5fe9f7d commit 86ca89a
Showing 12 changed files with 197 additions and 100 deletions.
115 changes: 70 additions & 45 deletions include/flashinfer/attention/scheduler.cuh
@@ -419,21 +419,41 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
return cudaSuccess;
}

inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
if (avg_packed_qo_len > 64 && head_dim < 256) {
return 128;
} else {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 64
return 64;
} else {
// avg_packed_qo_len <= 16
return 16;
}
} else {
// NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
return 64;
}
}
}

template <typename IdType>
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, uint32_t max_batch_size_if_split,
bool enable_cuda_graph) {
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t total_num_rows, uint32_t max_seq_len,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
uint32_t max_batch_size_if_split, bool enable_cuda_graph) {
std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
merge_indptr.push_back(0);
o_indptr.push_back(0);

const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
uint32_t total_num_rows = qo_indptr_h[batch_size];

// step 1: compute qo_chunk_size
// step 1: determine packed_qo_len_arr and verify qo_indptr contents.
std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
if (packed_qo_len_arr[i] < 0) {
@@ -449,41 +469,43 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
<< kv_indptr_h[i] << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
sum_packed_qo_len += packed_qo_len_arr[i];
}
int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;

// step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
uint32_t cta_tile_q;
if (avg_packed_qo_len > 64 && head_dim < 256) {
cta_tile_q = 128;
uint32_t total_num_tiles_q;
bool split_kv;
int64_t kv_chunk_size, new_batch_size;
if (enable_cuda_graph) {
// When CUDA graphs are enabled, the lengths of sequences determined by
// qo_indptr_h can vary. We assume that the dummy data based on which
// the CUDA graph is created fixes the maximum number of tokens.
uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);

// Find an upper bound for the number of tiles, derived from the total
// number of rows and the batch size. The sum of qo lengths rounded
// up to cta_tile_q will not exceed this number derived from the total
// number of rows.
total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size;

split_kv = true;
kv_chunk_size = max_batch_size_if_split;
new_batch_size = max_batch_size_if_split;
} else {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 64
cta_tile_q = 64;
} else {
// avg_packed_qo_len <= 16
cta_tile_q = 16;
}
} else {
// NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
cta_tile_q = 64;
total_num_tiles_q = 0;
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
sum_packed_qo_len += packed_qo_len_arr[i];
}
}

uint32_t total_num_tiles_q = 0;
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
}
const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);

// step 2: determine kv_chunk_size
auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
/*min_kv_chunk_size=*/std::max((128 / page_size), 1U));

if (enable_cuda_graph) {
split_kv = total_num_tiles_q < max_batch_size_if_split;
std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
/*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
}

// step 3: split qo_indptr and kv_indptr
@@ -511,7 +533,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
kv_chunk_size *= page_size;

return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
std::move(request_indices), std::move(qo_tile_indices),
std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
}

@@ -597,9 +619,10 @@ template <typename IdType>
inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, void* page_locked_int_buffer,
size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
uint32_t max_seq_len, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
bool enable_cuda_graph, uint32_t sizeof_dtype_o,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
@@ -618,17 +641,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;

// step 2: determine kv_chunk_size
auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
o_indptr_vec] =
PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size,
num_qo_heads, num_kv_heads, head_dim, page_size,
max_batch_size_if_split, enable_cuda_graph);
plan_info.cta_tile_q = cta_tile_q;
plan_info.total_num_rows = total_num_rows;

plan_info.enable_cuda_graph = enable_cuda_graph;
size_t padded_batch_size =
enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;

plan_info.padded_batch_size = padded_batch_size;
plan_info.split_kv = split_kv;

@@ -679,6 +703,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");

IdType* merge_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
bool* block_valid_mask_h =
9 changes: 5 additions & 4 deletions python/csrc/batch_prefill.cu
@@ -42,8 +42,9 @@ using namespace flashinfer;
std::vector<int64_t> BatchPrefillWithKVCachePlan(
unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
unsigned int total_num_rows, unsigned int max_seq_len, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -58,8 +59,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
kv_indptr.data_ptr<IdType>(), total_num_rows, max_seq_len, batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);

TORCH_CHECK(status == cudaSuccess,
"Failed to plan prefill with error: ", cudaGetErrorString(status));
5 changes: 3 additions & 2 deletions python/csrc/flashinfer_ops.cu
@@ -99,8 +99,9 @@ void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at:
std::vector<int64_t> BatchPrefillWithKVCachePlan(
unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
unsigned total_num_rows, unsigned int max_seq_len, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheRun(
unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
2 changes: 2 additions & 0 deletions python/flashinfer/decode.py
@@ -776,7 +776,9 @@ def plan(
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
batch_size, # total_num_rows
batch_size,
1, # max_seq_len
num_qo_heads,
num_kv_heads,
page_size,
9 changes: 7 additions & 2 deletions python/flashinfer/jit/batch_prefill_templ.py
@@ -139,6 +139,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr,
at::Tensor kv_indptr,
unsigned int total_num_rows,
unsigned int max_seq_len,
unsigned int batch_size,
unsigned int num_qo_heads,
unsigned int num_kv_heads,
@@ -156,8 +158,9 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{dtype_idx}}>(),
kv_indptr.data_ptr<{{dtype_idx}}>(), batch_size, num_qo_heads, num_kv_heads, {{head_dim}},
page_size, enable_cuda_graph, sizeof({{dtype_o}}), stream);
kv_indptr.data_ptr<{{dtype_idx}}>(), total_num_rows, max_seq_len,
batch_size, num_qo_heads, num_kv_heads, {{head_dim}}, page_size,
enable_cuda_graph, sizeof({{dtype_o}}), stream);
TORCH_CHECK(status == cudaSuccess,
"Failed to plan prefill with error: ", cudaGetErrorString(status));
@@ -457,6 +460,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr,
at::Tensor kv_indptr,
unsigned int total_num_rows,
unsigned int max_seq_len,
unsigned int batch_size,
unsigned int num_qo_heads,
unsigned int num_kv_heads,
67 changes: 59 additions & 8 deletions python/flashinfer/prefill.py
@@ -833,6 +833,8 @@ def __init__(
self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
self._custom_mask_buf = custom_mask_buf
self._qk_indptr_buf = qk_indptr_buf
self._max_total_num_rows = None
self._max_seq_len = None

@property
def is_cuda_graph_enabled(self) -> bool:
@@ -993,7 +995,33 @@ def plan(
bitorder="little",
)

# NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
qo_indptr_host = qo_indptr.to("cpu")
paged_kv_indptr_host = paged_kv_indptr.to("cpu")

total_num_rows = qo_indptr_host[-1]
max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()

if self.is_cuda_graph_enabled:
if self._max_total_num_rows is None:
self._max_total_num_rows = total_num_rows
elif total_num_rows > self._max_total_num_rows:
raise ValueError(
"The total number of rows in qo_indptr {} in cuda graph mode cannot "
"exceed the number of rows set during initialization {}.".format(
total_num_rows, self._max_total_num_rows
)
)
if self._max_seq_len is None:
self._max_seq_len = max_seq_len
elif max_seq_len > self._max_seq_len:
raise ValueError(
"The maximum sequence length in qo_indptr {} in cuda graph mode cannot "
"exceed the sequence length set during initialization {}.".format(
max_seq_len, self._max_seq_len
)
)

if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed during the lifecycle of the wrapper in "
@@ -1049,10 +1077,6 @@ def plan(
self.device, non_blocking=non_blocking
)

# NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
qo_indptr_host = qo_indptr.to("cpu")
paged_kv_indptr_host = paged_kv_indptr.to("cpu")

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
self._cached_module = get_batch_prefill_module(
@@ -1073,6 +1097,8 @@ def plan(
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
paged_kv_indptr_host,
total_num_rows,
max_seq_len,
batch_size,
num_qo_heads,
num_kv_heads,
@@ -1463,6 +1489,7 @@ def __init__(
self._kv_indptr_buf = kv_indptr_buf
self._custom_mask_buf = custom_mask_buf
self._qk_indptr_buf = qk_indptr_buf
self._max_total_num_rows = None

@property
def is_cuda_graph_enabled(self) -> bool:
@@ -1610,7 +1637,33 @@ def plan(
bitorder="little",
)

# NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
qo_indptr_host = qo_indptr.to("cpu")
paged_kv_indptr_host = paged_kv_indptr.to("cpu")

total_num_rows = qo_indptr_host[-1]
max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()

if self.is_cuda_graph_enabled:
if self._max_total_num_rows is None:
self._max_total_num_rows = total_num_rows
elif total_num_rows > self._max_total_num_rows:
raise ValueError(
"The total number of rows in qo_indptr {} in cuda graph mode cannot "
"exceed the number of rows set during initialization {}.".format(
total_num_rows, self._max_total_num_rows
)
)
if self._max_seq_len is None:
self._max_seq_len = max_seq_len
elif max_seq_len > self._max_seq_len:
raise ValueError(
"The maximum sequence length in qo_indptr {} in cuda graph mode cannot "
"exceed the sequence length set during initialization {}.".format(
max_seq_len, self._max_seq_len
)
)

if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
@@ -1638,10 +1691,6 @@ def plan(
self._custom_mask_buf = packed_custom_mask.to(self.device)
self._qk_indptr_buf = qk_indptr.to(self.device)

# NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
qo_indptr_host = qo_indptr.to("cpu")
kv_indptr_host = kv_indptr.to("cpu")

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
self._cached_module = get_batch_prefill_module(
@@ -1662,6 +1711,8 @@ def plan(
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
kv_indptr_host,
total_num_rows,
max_seq_len,
batch_size,
num_qo_heads,
num_kv_heads,
10 changes: 5 additions & 5 deletions src/bench_batch_decode.cu
@@ -144,11 +144,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);

handler.Plan<T, int32_t>((void*)thrust::raw_pointer_cast(float_buffer.data()),
float_workspace_size_in_bytes,
(void*)thrust::raw_pointer_cast(int_buffer.data()),
int_workspace_size_in_bytes, qo_indptr_h.data(), kv_indptr_host.data(),
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
handler.Plan<T, int32_t>(
(void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
(void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
qo_indptr_h.data(), kv_indptr_host.data(), /*total_num_rows=*/batch_size, /*max_seq_len=*/1,
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(