From 98eaf1fee41abcf2765c4e35081c34a249f3e783 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 2 Dec 2024 18:02:18 +0000 Subject: [PATCH 1/6] add working kernel with padded_max_seq_len as arg Signed-off-by: NickLucche --- .../kernels/benchmark_paged_attention.py | 2 + csrc/attention/attention_kernels.cuh | 42 +++++++---- csrc/attention/paged_attention_v1.cu | 42 ++++++----- csrc/attention/paged_attention_v2.cu | 43 +++++++---- csrc/cpu/attention.cpp | 22 ++++-- csrc/cpu/torch_bindings.cpp | 11 +-- csrc/ops.h | 17 ++--- csrc/torch_bindings.cpp | 32 ++------ tests/kernels/test_attention.py | 75 ++++++++++++------- vllm/_custom_ops.py | 8 +- vllm/attention/backends/blocksparse_attn.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/xformers.py | 2 + vllm/attention/ops/paged_attn.py | 19 +++-- 14 files changed, 186 insertions(+), 131 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index daedaadb1a77b..55b0305dc1da2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -118,6 +118,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -138,6 +139,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf10..08f9882f65f09 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,8 +104,10 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, + const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -154,6 +156,14 @@ __device__ void paged_attention_kernel( const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence + // and current head. + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * padded_max_seq_len + + head_idx * padded_max_seq_len; + // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread // group fetch or compute 16 bytes at a time. For example, if the size of a @@ -285,7 +295,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, *k_scale); + k_vec_quant, k_scale); } } @@ -293,8 +303,10 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. 
float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. + // Add the ALiBi bias if slopes are given, then add custom bias if given. + // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -415,7 +427,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - *v_scale); + v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -512,17 +524,19 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, + const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, + max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -548,17 +562,19 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, + const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, + padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). 
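NOTE (editor): the hunk above is the core of the change — the kernel receives a flat float32 bias laid out as [num_seqs, num_heads, padded_max_seq_len], takes the row for the current sequence/head via `attn_bias + seq_idx * num_heads * padded_max_seq_len + head_idx * padded_max_seq_len`, and adds `attn_bias_vec[token_idx]` to each logit after the ALiBi term. The following is a minimal PyTorch sketch of the per-sequence decode computation this implements, mirroring ref_single_query_cached_kv_attention in tests/kernels/test_attention.py; all shapes and names here are illustrative assumptions, and the real kernel additionally handles ALiBi, GQA head mapping, KV-cache quantization scales, and block tables. In later commits of this series the last bias dimension must be block-aligned (padded_max_seq_len), as sketched below.

import torch

# Illustrative sizes only (not taken from the patch).
num_seqs, num_heads, head_size, block_size = 2, 4, 64, 16
seq_lens = [5, 19]
max_seq_len = max(seq_lens)
# Last dim is padded to a multiple of block_size, as the kernel expects.
padded_max_seq_len = (max_seq_len + block_size - 1) // block_size * block_size

scale = head_size ** -0.5
q = torch.randn(num_seqs, num_heads, head_size)
keys = torch.randn(num_seqs, max_seq_len, num_heads, head_size)
values = torch.randn(num_seqs, max_seq_len, num_heads, head_size)

# Bias buffer in the layout the kernel indexes: [num_seqs, num_heads, padded_max_seq_len],
# float32, with only the first seq_len entries of each row meaningful.
attn_bias = torch.zeros(num_seqs, num_heads, padded_max_seq_len, dtype=torch.float)
for i, seq_len in enumerate(seq_lens):
    attn_bias[i, :, :seq_len] = torch.randn(num_heads, seq_len)

for seq_idx in range(num_seqs):
    seq_len = seq_lens[seq_idx]
    # qk = scale * dot(q, k) for every cached token, then add the per-token bias
    # (the kernel does qk += attn_bias_vec[token_idx] after the optional ALiBi term).
    logits = scale * torch.einsum("hd,thd->ht", q[seq_idx], keys[seq_idx, :seq_len])
    logits += attn_bias[seq_idx, :, :seq_len]
    probs = torch.softmax(logits, dim=-1)
    out = torch.einsum("ht,thd->hd", probs, values[seq_idx, :seq_len])  # [num_heads, head_size]

This is only a reference sketch under the stated assumptions; the CUDA kernel computes the same quantity incrementally over KV-cache blocks and partitions rather than materializing full logits.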
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b1014a..cbe5d1dd6f3f6 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -40,10 +40,10 @@ <<>>( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template & alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -73,15 +74,22 @@ void paged_attention_v1_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); - const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); - const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = @@ -137,8 +145,8 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \ + tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ @@ -178,10 +186,10 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const 
std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02fb1..2b25a6afe765f 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -36,10 +36,11 @@ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ + attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -54,10 +55,11 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -74,7 +76,16 @@ void paged_attention_v2_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); @@ -84,11 +95,11 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); - const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); - const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -144,7 +155,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -189,10 +200,10 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index b9764056e8a2d..eb33c66953a6e 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -459,14 +459,17 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU 
backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -781,14 +784,17 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 5d1c5f4c83d3e..3cfa289848e21 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -24,13 +24,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. + // TODO attn_bias on cpu ops.def( "paged_attention_v1(" " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -43,8 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -148,7 +149,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! 
value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index e39d4ef3188a3..f2ad92074f446 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,10 +33,10 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -45,10 +45,10 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -153,7 +153,6 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, #ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); -bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c03806f430a7c..5ab404523494f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -29,8 +29,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -43,8 +43,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," + " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -324,13 +324,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); - // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) - ops.def( - "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " - "bool"); - ops.impl("cutlass_scaled_mm_supports_block_fp8", - &cutlass_scaled_mm_supports_fp8); - // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability ops.def( @@ -450,17 +443,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); - cache_ops.def( - "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()"); - cache_ops.impl("copy_blocks_mla", torch::kCUDA, ©_blocks_mla); - // Reshape the key and value tensors and cache them. cache_ops.def( "reshape_and_cache(Tensor key, Tensor value," " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. @@ -470,19 +459,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); - // Concat kv_c and k_pe and cache them. - cache_ops.def( - "concat_and_cache_mla(Tensor kv_c, Tensor k_pe," - " Tensor! kv_cache," - " Tensor slot_mapping," - " str kv_cache_dtype," - " Tensor scale) -> ()"); - cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); - // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b667d8d9e0307..380f868a29248 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -20,7 +20,8 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = 16 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. 
NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -31,6 +32,7 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in @@ -39,6 +41,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +USE_CUSTOM_ATTN_BIAS = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -62,16 +65,11 @@ def ref_masked_attention( def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: + output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int, + key_cache: torch.Tensor, value_cache: torch.Tensor, + block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, + alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[List[torch.Tensor]]) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] head_size = value_cache.shape[2] @@ -104,15 +102,19 @@ def ref_single_query_cached_kv_attention( keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - alibi_bias = None + bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) + bias = alibi_bias + if attn_bias is not None: + # TODO test alibi + bias + bias = attn_bias[i] if bias is None else bias + attn_bias[i] + # print(f"ATTN BIAS {i}: {attn_bias[i]}") + out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -124,6 +126,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @@ -136,12 +139,17 @@ def test_paged_attention( num_heads: Tuple[int, int], head_size: int, use_alibi: bool, + use_custom_attn_bias: bool, block_size: int, dtype: torch.dtype, kv_cache_dtype: str, seed: int, device: str, ) -> None: + # num_heads = (2, 2) + # num_seqs = 2 + # head_size = 32 + if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -155,7 +163,7 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None + alibi_slopes, attn_bias = None, None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) @@ -163,6 +171,20 @@ def test_paged_attention( seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) + attn_bias_list = None + if use_custom_attn_bias: + # NOTE (NickLucche) each sequence can have a different 
bias, + # depending on its len, but it *must* be float (f32)! + attn_bias_list = [torch.randn(num_query_heads, + 1, + seq_len, + dtype=torch.float) for seq_len in seq_lens] + attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + + for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): + # first seq_len entries of the bias for each head/seq + attn_bias[i, :, :, :seq_len] = bias + # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -188,6 +210,7 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) + # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -201,19 +224,23 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, ) + # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -242,6 +269,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -251,6 +279,7 @@ def test_paged_attention( (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -307,17 +336,11 @@ def test_paged_attention( value_cache = dequantized_value_cache ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - ) + ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv, + key_cache, value_cache, block_tables, + seq_lens, scale, alibi_slopes, + attn_bias_list) + # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a682350167675..98d7c9f875cab 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -49,6 +49,7 @@ def paged_attention_v1( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -60,8 +61,8 @@ def paged_attention_v1( ) -> None: torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, + seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -81,6 +82,7 @@ def paged_attention_v2( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], 
kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -93,7 +95,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9765e7881ad9d..ee58696b18918 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -443,6 +443,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b7c..3b1944f47e6ab 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -837,6 +837,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b35..9eb8087e42f40 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -583,6 +583,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, + prefill_meta.attn_bias, self.sliding_window, layer._k_scale, layer._v_scale, @@ -611,6 +612,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + decode_meta.attn_bias, layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38d66..bc651fa0dc326 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,5 +1,3 @@ -# SPDX-License-Identifier: Apache-2.0 - from dataclasses import dataclass from typing import List, Optional, Tuple @@ -71,8 +69,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, + k_scale: float, + v_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -97,8 +95,9 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: torch.Tensor, - v_scale: torch.Tensor, + attn_bias: Optional[torch.Tensor], + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -142,6 +141,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -180,6 +180,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -205,11 +206,13 @@ def forward_prefix( context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], sliding_window: Optional[int], - k_scale: torch.Tensor, - v_scale: torch.Tensor, + k_scale: float, + v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) + assert attn_bias is None, "Bias for prefix not yet enabled" context_attention_fwd( query, key, From 752be56e7718ca99b84975bc72ae0c8bf6b054c1 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 
2024 10:22:50 +0000 Subject: [PATCH 2/6] add attn_bias case to pagedattn tests Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 51 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 380f868a29248..cd59bd20e40df 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -20,8 +20,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -MAX_SEQ_LEN = 16 +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -32,7 +31,6 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in @@ -111,9 +109,7 @@ def ref_single_query_cached_kv_attention( 1, 1, -1) bias = alibi_bias if attn_bias is not None: - # TODO test alibi + bias bias = attn_bias[i] if bias is None else bias + attn_bias[i] - # print(f"ATTN BIAS {i}: {attn_bias[i]}") out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -146,10 +142,6 @@ def test_paged_attention( seed: int, device: str, ) -> None: - # num_heads = (2, 2) - # num_seqs = 2 - # head_size = 32 - if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -175,16 +167,20 @@ def test_paged_attention( if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, # depending on its len, but it *must* be float (f32)! - attn_bias_list = [torch.randn(num_query_heads, - 1, - seq_len, - dtype=torch.float) for seq_len in seq_lens] - attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + attn_bias_list = [ + torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) + for seq_len in seq_lens + ] + attn_bias = torch.empty(num_seqs, + num_query_heads, + 1, + max_seq_len, + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): - # first seq_len entries of the bias for each head/seq + # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias - # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -210,7 +206,6 @@ def test_paged_attention( # Call the paged attention kernel. 
output = torch.empty_like(query) - # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -229,18 +224,15 @@ def test_paged_attention( k_scale, v_scale, ) - # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): - assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -275,14 +267,14 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -340,7 +332,6 @@ def test_paged_attention( key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes, attn_bias_list) - # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two From 8703add8825644891ad60b7cb4110b29ccd72b17 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:32:29 +0000 Subject: [PATCH 3/6] format Signed-off-by: NickLucche --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- vllm/attention/backends/xformers.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 55b0305dc1da2..98645a0cd4cc0 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -118,7 +118,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, - None, # TODO add custom bias + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 9eb8087e42f40..8cb703fc498c0 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -612,6 +612,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + # TODO (NickLucche) cross_attn_bias not needed for T5-like + # models, abstract bias selection if needed. 
decode_meta.attn_bias, layer._k_scale, layer._v_scale, From 0ef1470afa6453e13c4d7688c2f916751b2932d7 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 27 Dec 2024 17:21:43 +0000 Subject: [PATCH 4/6] enforce last dim of attn bias to be block aligned Signed-off-by: NickLucche --- csrc/attention/paged_attention_v1.cu | 17 ++++++++++------- csrc/attention/paged_attention_v2.cu | 21 ++++++++++++--------- tests/kernels/test_attention.py | 19 +++++++++++-------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index cbe5d1dd6f3f6..0b04b55a4e13a 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -77,12 +77,17 @@ void paged_attention_v1_launcher( const float* attn_bias_ptr = attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -92,13 +97,11 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 2b25a6afe765f..5eeba75d5cf1c 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -79,12 +79,17 @@ void paged_attention_v2_launcher( const float* attn_bias_ptr = attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). 
However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -97,18 +102,16 @@ void paged_attention_v2_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel. dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index cd59bd20e40df..314db5db12139 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -164,26 +164,29 @@ def test_paged_attention( max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) attn_bias_list = None + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, - # depending on its len, but it *must* be float (f32)! + # depending on its len, but it *must* be padded to the block + # aligned max_seq_len and of type float32! attn_bias_list = [ torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) for seq_len in seq_lens ] - attn_bias = torch.empty(num_seqs, - num_query_heads, - 1, - max_seq_len, - device=device, - dtype=torch.float) + block_aligned_max_seq_len = max_num_blocks_per_seq * block_size + attn_bias = torch.empty( + num_seqs, + num_query_heads, + 1, + block_aligned_max_seq_len, # padded dim + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias # Create the block tables. 
- max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ From 6228f76c7b6cfdc6270b89bfc706a39487cff0ff Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 15 Jan 2025 14:28:52 +0000 Subject: [PATCH 5/6] fix blocksparse tests Signed-off-by: NickLucche --- tests/kernels/test_blocksparse_attention.py | 2 ++ vllm/attention/ops/ipex_attn.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index e653d34d00ee1..bedff278f95f2 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -230,6 +230,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -267,6 +268,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 598ceea130d97..150c99fc97ff8 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -105,6 +105,7 @@ def forward_decode( block_size, max_context_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, From c7c983d5e18c758ce09dc5c12a77d830142522eb Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 17 Feb 2025 16:04:01 +0000 Subject: [PATCH 6/6] fix rebase Signed-off-by: NickLucche --- csrc/attention/attention_kernels.cuh | 10 +++--- csrc/attention/paged_attention_v1.cu | 49 +++++++++++++++------------- csrc/attention/paged_attention_v2.cu | 27 ++++++++------- csrc/cpu/attention.cpp | 16 ++++----- csrc/cpu/torch_bindings.cpp | 6 ++-- csrc/ops.h | 11 ++++--- csrc/torch_bindings.cpp | 28 +++++++++++++--- vllm/attention/ops/paged_attn.py | 13 ++++---- 8 files changed, 95 insertions(+), 65 deletions(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 08f9882f65f09..d19771a19cc0a 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -107,7 +107,7 @@ __device__ void paged_attention_kernel( const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -295,7 +295,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); + k_vec_quant, *k_scale); } } @@ -427,7 +427,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); + *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -527,7 +527,7 @@ __global__ void paged_attention_v1_kernel( const float* __restrict__ attn_bias, const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. @@ -53,11 +53,11 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, - const c10::optional& attn_bias, float k_scale, float v_scale, - const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, + const std::optional& attn_bias, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -95,6 +95,8 @@ void paged_attention_v1_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int logits_size = padded_max_seq_len * sizeof(float); @@ -189,10 +191,11 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const c10::optional& attn_bias, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::optional& alibi_slopes, + const std::optional& attn_bias, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = 
(blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 5eeba75d5cf1c..ccfd6cd60f55c 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -38,9 +38,9 @@ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); \ + kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -55,11 +55,11 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, - const c10::optional& attn_bias, float k_scale, float v_scale, - const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, + const std::optional& attn_bias, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -100,6 +100,8 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); @@ -203,10 +205,11 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const c10::optional& alibi_slopes, - const c10::optional& attn_bias, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::optional& alibi_slopes, + const std::optional& attn_bias, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index eb33c66953a6e..4c67a775a0bd8 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -461,11 +461,11 @@ void paged_attention_v1( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const c10::optional& attn_bias, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& 
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   TORCH_CHECK(!attn_bias.has_value(),
@@ -784,13 +784,13 @@ void paged_attention_v2(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const c10::optional<torch::Tensor>& attn_bias,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   TORCH_CHECK(!attn_bias.has_value(),
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 3cfa289848e21..9ddb837d18200 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -31,7 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
-      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
@@ -45,7 +45,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
-      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
@@ -149,7 +149,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index f2ad92074f446..6a467fd6f378b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -35,8 +35,9 @@ void paged_attention_v1( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const c10::optional& attn_bias, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -47,8 +48,9 @@ void paged_attention_v2( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const c10::optional& attn_bias, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -153,6 +155,7 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, #ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 5ab404523494f..29e91762c74bc 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -324,6 +324,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) + ops.def( + "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " + "bool"); + ops.impl("cutlass_scaled_mm_supports_block_fp8", + &cutlass_scaled_mm_supports_fp8); + // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability ops.def( @@ -443,13 +450,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); + cache_ops.def( + "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()"); + cache_ops.impl("copy_blocks_mla", torch::kCUDA, ©_blocks_mla); + // Reshape the key and value tensors and cache them. cache_ops.def( "reshape_and_cache(Tensor key, Tensor value," " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. @@ -459,10 +470,19 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + // Concat kv_c and k_pe and cache them. + cache_ops.def( + "concat_and_cache_mla(Tensor kv_c, Tensor k_pe," + " Tensor! kv_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor scale) -> ()"); + cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! 
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index bc651fa0dc326..89ae554bff987 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
@@ -69,8 +70,8 @@ def write_to_paged_cache(
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
         kv_cache_dtype: str,
-        k_scale: float,
-        v_scale: float,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
     ) -> None:
         ops.reshape_and_cache(
             key,
@@ -96,8 +97,8 @@ def forward_decode(
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
         attn_bias: Optional[torch.Tensor],
-        k_scale: float,
-        v_scale: float,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
         tp_rank: int = 0,
         blocksparse_local_blocks: int = 0,
         blocksparse_vert_stride: int = 0,
@@ -208,8 +209,8 @@ def forward_prefix(
         alibi_slopes: Optional[torch.Tensor],
         attn_bias: Optional[torch.Tensor],
         sliding_window: Optional[int],
-        k_scale: float,
-        v_scale: float,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         assert attn_bias is None, "Bias for prefix not yet enabled"
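
As a closing usage sketch, and again not part of the patch: the attn_bias argument accepted by PagedAttention.forward_decode is an additive float32 bias with layout [num_seqs, num_heads, padded_max_seq_len]. The sizes, the block-size rounding of the padded length, and the bias values below are assumptions made for the example.

    import torch

    num_seqs, num_heads, block_size = 4, 8, 16
    seq_lens = torch.tensor([7, 12, 3, 9], dtype=torch.int32, device="cuda")
    # Pad all rows to one length so the kernel can use a fixed per-head stride
    # (assumed here to be the max sequence length rounded up to the block size).
    max_len = int(seq_lens.max())
    padded_max_seq_len = ((max_len + block_size - 1) // block_size) * block_size

    # One bias value per (sequence, head, key token), added to the q/k logits.
    attn_bias = torch.zeros(num_seqs, num_heads, padded_max_seq_len,
                            dtype=torch.float32, device="cuda")
    for i, n in enumerate(seq_lens.tolist()):
        # Example bias: mildly favor the most recent tokens of each sequence;
        # positions at or beyond n are padding.
        attn_bias[i, :, :n] = torch.linspace(-1.0, 0.0, n, device="cuda")

    k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    # attn_bias, k_scale and v_scale are then passed as the corresponding
    # arguments of PagedAttention.forward_decode shown in the diff above.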