misc: enhance allocator error info and add shape check for prefill begin forward functions (#413)

This PR makes the following changes to the codebase:
1. Make the allocator error messages more informative: runtime errors now
report the buffer name and the requested buffer size to aid debugging (a
simplified sketch of the pattern follows this list).
2. Add checks to the prefill wrappers' `begin_forward` functions to ensure that
the `qo` and `kv` indptr array sizes match (a sketch appears after this message).
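
For item 1, here is a minimal, self-contained sketch of the pattern adopted in `AlignedAllocator::aligned_alloc` (simplified from the diff below; the struct name here is illustrative): the allocator carves aligned sub-buffers out of a single workspace and, when it runs out of space, reports which named buffer failed and how many bytes it requested.

```cpp
#include <cstddef>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

// Simplified stand-alone sketch of the commit's allocator change: on failure,
// the runtime error names the buffer and includes the requested size/alignment.
struct NamedAlignedAllocator {
  void* ptr;
  size_t space;
  NamedAlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}

  template <typename T>
  T* aligned_alloc(size_t size, size_t alignment, const std::string& name) {
    if (std::align(alignment, size, ptr, space)) {
      T* result = reinterpret_cast<T*>(ptr);
      ptr = static_cast<char*>(ptr) + size;
      space -= size;
      return result;
    }
    std::ostringstream oss;
    oss << "Failed to allocate memory for " << name << " with size " << size
        << " and alignment " << alignment << " in AlignedAllocator";
    throw std::runtime_error(oss.str());
  }
};

// Example call site (buffer name is illustrative): a failed allocation now throws
//   "Failed to allocate memory for batch_decode_tmp_s with size ... and alignment 16 ..."
// instead of a generic out-of-workspace error.
```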

These changes are intended to help with issues such as #362, which ultimately
needs to be fixed on the vLLM side; regardless, we should provide friendlier
debugging information for locating such potential bugs.
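
For item 2, a minimal sketch of the new shape checks in the paged-KV `begin_forward` wrapper (the real code uses FlashInfer's `CHECK_DIM`/`CHECK_EQ` macros, as shown in the diff below; this sketch uses plain `TORCH_CHECK` as a stand-in and the helper name is hypothetical): both indptr tensors must be 1-D with exactly `batch_size + 1` entries, since a CSR-style indptr over `batch_size` requests has one offset per request plus a trailing end offset.

```cpp
#include <torch/extension.h>

// Hypothetical helper mirroring the checks added to BeginForward: reject
// mismatched qo/kv indptr shapes up front instead of failing inside the kernel.
void CheckIndptrShapes(const torch::Tensor& qo_indptr, const torch::Tensor& paged_kv_indptr,
                       int64_t batch_size) {
  TORCH_CHECK(qo_indptr.dim() == 1, "qo_indptr must be a 1-D tensor");
  TORCH_CHECK(paged_kv_indptr.dim() == 1, "paged_kv_indptr must be a 1-D tensor");
  TORCH_CHECK(qo_indptr.size(0) == batch_size + 1, "qo_indptr.size(0) (", qo_indptr.size(0),
              ") must equal batch_size + 1 (", batch_size + 1, ")");
  TORCH_CHECK(paged_kv_indptr.size(0) == batch_size + 1, "paged_kv_indptr.size(0) (",
              paged_kv_indptr.size(0), ") must equal batch_size + 1 (", batch_size + 1, ")");
}
```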
yzh119 authored Jul 31, 2024
1 parent 9907bc1 commit 5e36c52
Showing 4 changed files with 86 additions and 56 deletions.
8 changes: 6 additions & 2 deletions include/flashinfer/allocator.h
@@ -17,6 +17,7 @@
#define FLASHINFER_ALLOCATOR_H_

#include <memory>
#include <sstream>
#include <stdexcept>

namespace flashinfer {
@@ -26,14 +27,17 @@ struct AlignedAllocator {
size_t space;
AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
template <typename T>
T* aligned_alloc(size_t size, size_t alignment) {
T* aligned_alloc(size_t size, size_t alignment, std::string name) {
if (std::align(alignment, size, ptr, space)) {
T* result = reinterpret_cast<T*>(ptr);
ptr = (char*)ptr + size;
space -= size;
return result;
} else {
throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
std::ostringstream oss;
oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment "
<< alignment << " in AlignedAllocator";
throw std::runtime_error(oss.str());
}
return nullptr;
}
114 changes: 67 additions & 47 deletions include/flashinfer/attention/handler.cuh
@@ -111,7 +111,6 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
high = mid;
}
}

new_batch_size = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
@@ -340,32 +339,37 @@ class BatchDecodeHandler {
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ =
allocator.aligned_alloc<float>(num_qo_heads * padded_batch_size * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16,
"batch_decode_tmp_v");
tmp_s_ = allocator.aligned_alloc<float>(num_qo_heads * padded_batch_size * sizeof(float),
16, "batch_decode_tmp_s");
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16,
"batch_decode_new_indptr");

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
"batch_decode_new_last_page_len");
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType),
16, "batch_decode_chunk_indptr");
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
"batch_decode_batch_idx_map");
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
"batch_decode_chunk_start_pos");
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
seq_lengths_before_partition_ = allocator.aligned_alloc<void>(
padded_batch_size * sizeof(IdType), 16, "batch_decode_seq_lengths_before_partition");
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16,
"batch_decode_block_valid_mask");
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
@@ -390,30 +394,32 @@ class BatchDecodeHandler {
if (split_kv) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ =
allocator.aligned_alloc<float>(num_qo_heads * new_batch_size * sizeof(float), 16);
new_indptr_ =
allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16,
"batch_decode_tmp_v");
tmp_s_ = allocator.aligned_alloc<float>(num_qo_heads * new_batch_size * sizeof(float), 16,
"batch_decode_tmp_s");
new_indptr_ = allocator.aligned_alloc<void>(
(batch_size_after_partition_ + 1) * sizeof(IdType), 16, "batch_decode_new_indptr");
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
new_last_page_len_ = allocator.aligned_alloc<void>(
batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_new_last_page_len");
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>(
(batch_size_before_partition_ + 1) * sizeof(IdType), 16);
(batch_size_before_partition_ + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr");
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
batch_idx_map_ = allocator.aligned_alloc<void>(
batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_batch_idx_map");
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
chunk_start_pos_ = allocator.aligned_alloc<void>(
batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_chunk_start_pos");
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16,
"batch_decode_seq_lengths_before_partition");
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
@@ -678,27 +684,34 @@ class BatchPrefillHandler {
if (IsCUDAGraphEnabled()) {
padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q);
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
"batch_prefill_request_indices");
void* request_indices_h_ = page_locked_buffer_;
qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
"batch_prefill_qo_tile_indices");
void* qo_tile_indices_h_ =
(char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_);
kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
"batch_prefill_kv_tile_indices");
void* kv_tile_indices_h_ =
(char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_);
o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (batch_size + 1), 16);
o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (batch_size + 1), 16,
"batch_prefill_o_indptr");
void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_);
kv_chunk_size_ptr_ = allocator.aligned_alloc<void>(sizeof(IdType), 1);
kv_chunk_size_ptr_ =
allocator.aligned_alloc<void>(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
void* kv_chunk_size_ptr_h_ =
(char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
*(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
if (total_num_tiles_q < split_max_batch_size) {
// need merge_indptr
merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (total_num_rows_ + 1), 16);
merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (total_num_rows_ + 1), 16,
"batch_prefill_merge_indptr");
void* merge_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_);
std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_);
block_valid_mask_ = allocator.aligned_alloc<bool>(sizeof(bool) * padded_batch_size_, 16);
block_valid_mask_ = allocator.aligned_alloc<bool>(sizeof(bool) * padded_batch_size_, 16,
"batch_prefill_block_valid_mask");
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_);
for (uint32_t i = 0; i < padded_batch_size_; ++i) {
@@ -724,37 +737,42 @@ class BatchPrefillHandler {

if (total_num_tiles_q < split_max_batch_size) {
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16);
num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16,
"batch_prefill_tmp_v");
tmp_s_ = allocator.aligned_alloc<float>(
num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16);
num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16,
"batch_prefill_tmp_s");
} else {
tmp_v_ = nullptr;
tmp_s_ = nullptr;
}
} else {
padded_batch_size_ = new_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
request_indices_ =
allocator.aligned_alloc<void>(sizeof(IdType) * request_indices_vec.size(), 16);
request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * request_indices_vec.size(),
16, "batch_prefill_request_indices");
void* request_indices_h_ = page_locked_buffer_;
qo_tile_indices_ =
allocator.aligned_alloc<void>(sizeof(IdType) * qo_tile_indices_vec.size(), 16);
qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * qo_tile_indices_vec.size(),
16, "batch_prefill_qo_tile_indices");
void* qo_tile_indices_h_ =
(char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_);
kv_tile_indices_ =
allocator.aligned_alloc<void>(sizeof(IdType) * kv_tile_indices_vec.size(), 16);
kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * kv_tile_indices_vec.size(),
16, "batch_prefill_kv_tile_indices");
void* kv_tile_indices_h_ =
(char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_);
if (split_kv) {
// need merge_indptr when split_kv is true
merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * merge_indptr_vec.size(), 16);
merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * merge_indptr_vec.size(), 16,
"batch_prefill_merge_indptr");
void* merge_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_);
std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_);
}
o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * o_indptr_vec.size(), 16);
o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * o_indptr_vec.size(), 16,
"batch_prefill_o_indptr");
void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_);
kv_chunk_size_ptr_ = allocator.aligned_alloc<void>(sizeof(IdType), 1);
kv_chunk_size_ptr_ =
allocator.aligned_alloc<void>(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
void* kv_chunk_size_ptr_h_ =
(char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
*(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
Expand All @@ -772,9 +790,11 @@ class BatchPrefillHandler {

if (split_kv) {
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16);
num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16,
"batch_prefill_tmp_v");
tmp_s_ = allocator.aligned_alloc<float>(
num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16);
num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16,
"batch_prefill_tmp_s");
} else {
tmp_v_ = nullptr;
tmp_s_ = nullptr;
14 changes: 7 additions & 7 deletions include/flashinfer/group_gemm/wrapper.cuh
@@ -43,13 +43,13 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
AlignedAllocator allocator(handler->GetWorkspace(), handler->GetWorkspaceSizeInBytes());
cutlass::gemm::GemmCoord* problem_sizes_device =
allocator.aligned_alloc<cutlass::gemm::GemmCoord>(
batch_size * sizeof(cutlass::gemm::GemmCoord), 16);
DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device");
DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "x_data");
DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "w_data");
DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "y_data");
int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_x");
int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_w");
int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_y");

// NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
// so I just use the kernel function directly, need to investigate more.
6 changes: 6 additions & 0 deletions python/csrc/batch_prefill.cu
@@ -30,7 +30,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
CHECK_CONTIGUOUS(paged_kv_indptr);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, paged_kv_indptr);
CHECK_DIM(1, workspace_buffer);
CHECK_EQ(qo_indptr.size(0), batch_size + 1);
CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
auto device = workspace_buffer.device();
@@ -361,7 +364,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
CHECK_CONTIGUOUS(qo_indptr);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, kv_indptr);
CHECK_DIM(1, workspace_buffer);
CHECK_EQ(qo_indptr.size(0), batch_size + 1);
CHECK_EQ(kv_indptr.size(0), batch_size + 1);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();