From 978aed53004b82877bd2af0f10afff1826d7194d Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 16 Jul 2024 18:31:32 -0400 Subject: [PATCH] [Kernel][Attention] Separate `Attention.kv_scale` into `k_scale` and `v_scale` (#6081) --- .../kernels/benchmark_paged_attention.py | 8 ++- csrc/attention/attention_kernels.cu | 53 ++++++++--------- csrc/cache.h | 4 +- csrc/cache_kernels.cu | 21 +++---- csrc/cpu/attention.cpp | 12 ++-- csrc/cpu/cache.cpp | 5 +- csrc/cpu/torch_bindings.cpp | 10 ++-- csrc/ops.h | 8 +-- csrc/torch_bindings.cpp | 10 ++-- tests/kernels/test_attention.py | 8 ++- tests/kernels/test_blocksparse_attention.py | 8 ++- tests/kernels/test_cache.py | 4 +- tests/quantization/test_fp8.py | 40 +++++++++++-- vllm/_custom_ops.py | 18 +++--- vllm/_ipex_ops.py | 9 ++- vllm/attention/backends/abstract.py | 3 +- vllm/attention/backends/blocksparse_attn.py | 9 ++- vllm/attention/backends/flash_attn.py | 6 +- vllm/attention/backends/flashinfer.py | 6 +- vllm/attention/backends/ipex_attn.py | 14 +++-- vllm/attention/backends/pallas.py | 5 +- vllm/attention/backends/rocm_flash_attn.py | 9 ++- vllm/attention/backends/torch_sdpa.py | 11 ++-- vllm/attention/backends/xformers.py | 8 ++- vllm/attention/layer.py | 14 +++-- vllm/attention/ops/ipex_attn.py | 6 +- vllm/attention/ops/paged_attn.py | 15 +++-- vllm/model_executor/layers/linear.py | 9 +++ .../model_executor/layers/quantization/fp8.py | 51 +++++++++++----- .../model_loader/weight_utils.py | 58 +++++++++++++++++-- vllm/model_executor/models/llama.py | 19 ++---- vllm/model_executor/models/mixtral.py | 21 ++----- vllm/model_executor/models/qwen2.py | 20 ++----- 33 files changed, 317 insertions(+), 185 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 16de60477c305..78cac8a555d1b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -100,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 for _ in range(num_iters): if version == "v1": @@ -117,7 +117,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) elif version == "v2": ops.paged_attention_v2( @@ -136,7 +137,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 91083481705cb..350dbce1d7ba9 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -105,9 +105,9 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; 
const int max_num_partitions = gridDim.z; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, kv_scale); + k_vec_quant, k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - kv_scale); + v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale, tp_rank, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel( out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale, tp_rank, blocksparse_local_blocks, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -694,8 +694,8 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float 
k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -770,7 +770,7 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); @@ -815,8 +815,8 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); @@ -833,7 +833,7 @@ void paged_attention_v1( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -932,8 +932,9 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ switch (is_block_sparse) { \ @@ -980,8 +981,8 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 86caa9345361d..52177e8901a89 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,8 +18,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, 
torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const double kv_scale); + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 72041076ae009..caef7f5e18630 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, - const float kv_scale) { + const int head_size, const int block_size, const int x, const float k_scale, + const float v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, kv_scale); + fp8::scaled_convert(tgt_key, k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, kv_scale); + fp8::scaled_convert(tgt_value, v_scale); } } } @@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, kv_scale); + num_heads, head_size, block_size, x, k_scale, v_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -258,7 +258,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double kv_scale) { + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -318,13 +319,13 @@ namespace vllm { template __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, - const float kv_scale, + const float scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; dst_cache[idx] = - fp8::scaled_convert(src_cache[idx], kv_scale); + fp8::scaled_convert(src_cache[idx], scale); } } @@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ vllm::convert_fp8_kernel<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); // Only for testing. 
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const double kv_scale, const std::string& kv_cache_dtype) { + const double scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..abb4e3bea14bb 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -423,11 +423,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -742,11 +742,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2b5c3bd6ee70b..31d454328b2c1 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,8 +107,9 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double kv_scale) { - TORCH_CHECK(kv_scale == 1.0f); + const std::string& kv_cache_dtype, double k_scale, + double v_scale) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); int num_tokens = key.size(0); int num_heads = key.size(1); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 39e8cf3ed3c10..5be0e9810b5b9 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); @@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index fb1099e4fe0c2..f9feb3deff5e4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -8,8 +8,8 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -19,8 +19,8 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 18331a674eeba..9dc7cefc404ca 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); @@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); @@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index f848ad51c7014..2e6412c28958e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -175,7 +175,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 # Call the paged attention kernel. output = torch.empty_like(query) @@ -193,7 +193,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) elif version == "v2": num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -224,7 +225,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 402545d1980d6..b3adb152949a2 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -212,7 +212,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 tp_rank = 0 # Call the paged attention kernel. @@ -231,7 +231,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, @@ -267,7 +268,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23b6baa60c05b..70ae3d0c6e0c3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -155,11 +155,11 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 # Call the reshape_and_cache kernel. 
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, k_scale, v_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0ed91cbb447fd..82dc775f8d812 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -7,19 +7,49 @@ from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, + Fp8LinearMethod) MODELS = [ - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", "nm-testing/Phi-3-mini-128k-instruct-FP8", ] @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") -@pytest.mark.parametrize("model", MODELS) -def test_model_load_and_run(vllm_runner, model: str): - with vllm_runner(model) as llm: +@pytest.mark.parametrize("model_id", MODELS) +def test_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id) as llm: + # note: this does not test accuracy, just that we can run through + # see lm-eval tests for accuracy + outputs = llm.generate_greedy(prompts=["Hello my name is"], + max_tokens=10) + print(outputs[0][1]) + + +KV_CACHE_MODELS = [ + # Deprecated AutoFP8 format using .kv_scale + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + # AutoFP8 format using separate .k_scale and .v_scale + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", +] + + +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_id", KV_CACHE_MODELS) +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + attn = model.model.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + # NOTE: it is valid for scales to be 1.0 (default value), but we know + # these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy outputs = llm.generate_greedy(prompts=["Hello my name is"], diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 03308d04012aa..4ca67224a91b8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -84,7 +84,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -94,8 +95,9 @@ def paged_attention_v1( torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) def paged_attention_v2( @@ -114,7 +116,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: 
float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -124,7 +127,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -374,11 +377,12 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, k_scale, v_scale) def reshape_and_cache_flash( diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 99a875c9b3fb7..b4721b4e1aedd 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -59,7 +59,8 @@ def paged_attention_v1( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -99,7 +100,8 @@ def paged_attention_v2( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -227,7 +229,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adb8325168cdf..1310bb1679e15 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -134,7 +134,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index fe4c4a45dca0d..6308cf07ce41e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -327,7 +327,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -368,7 +369,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) if prefill_meta := attn_metadata.prefill_metadata: @@ -405,7 +407,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, tp_rank=self.tp_rank, blocksparse_local_blocks=self.local_blocks, blocksparse_vert_stride=self.vert_stride, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 048abed48d2e9..0b6bd21279393 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -256,7 +256,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -277,7 +278,8 @@ def forward( "FlashAttentionImpl") # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b27e3e40f566d..a4b01c6d3b508 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -223,10 +223,12 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashInfer.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 6a1295b1000bc..4559dd15f600c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -156,7 +156,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -170,7 +171,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -192,7 +193,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) if attn_metadata.is_prompt: @@ -273,7 +275,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: # Run PagedAttention V2. @@ -305,7 +308,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. 
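For out-of-tree callers, the visible effect of the backend changes above is purely mechanical: every op that previously took a single `kv_scale: float` now takes `k_scale: float, v_scale: float` in its place. The sketch below mirrors the call shape from the updated `tests/kernels/test_cache.py` and `benchmarks/kernels/benchmark_paged_attention.py`; the wrapper function and its argument names are illustrative only, and running it assumes a CUDA build of vLLM that includes this change.

```python
# Minimal sketch of the new call shape (not a drop-in benchmark): the only
# difference from the old API is that two scales are passed instead of one.
import torch

from vllm import _custom_ops as ops


def cache_and_decode(query, key, value, key_cache, value_cache, slot_mapping,
                     block_tables, seq_lens, block_size, max_seq_len,
                     num_kv_heads, scale, alibi_slopes, kv_cache_dtype):
    # Default scales; FP8 checkpoints supply per-tensor values instead.
    k_scale = v_scale = 1.0

    # Write K/V into the paged cache; each tensor is quantized with its
    # own scale when the cache dtype is fp8.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    # Decode with PagedAttention V1; inside the kernel, K is dequantized
    # with k_scale and V with v_scale.
    output = torch.empty_like(query)
    ops.paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    return output
```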
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index c45f7b28b2afb..b83a83bb177d4 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -131,7 +131,8 @@ def forward( value: torch.Tensor, kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], attn_metadata: PallasMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -146,7 +147,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 81b546c65c819..f6ecea30da492 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -296,7 +296,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -336,7 +337,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -456,7 +458,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 48418f24870f9..fe6a56123ce72 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -144,7 +144,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -158,7 +159,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -176,7 +177,8 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, k_scale, + v_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -239,7 +241,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6cc5f1d1477ae..3dd60ed5be528 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -427,7 +427,8 @@ def forward( value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
@@ -531,7 +532,7 @@ def forward( value_cache, updated_slot_mapping, self.kv_cache_dtype, - kv_scale) + k_scale, v_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. @@ -620,7 +621,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b8cc87be8c748..0619bda90a2a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,13 +47,14 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads - # The default kv_scale is set to 1.0. This is ignored + # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized kv_scale to be loaded along + # expect the pre-quantized k/v_scale to be loaded along # with the model weights. self.kv_cache_dtype = kv_cache_dtype - self._kv_scale = 1.0 + self._k_scale = 1.0 + self._v_scale = 1.0 quant_method = quant_config.get_quant_method( self) if quant_config else None if quant_method is not None: @@ -66,8 +67,8 @@ def __init__( "fp8 checkpoints.") # When FP8 quantization is enabled, we make a parameter # "kv_scale" so that it can be loaded from FP8 checkpoint. - # The kv_scale will then be converted back to self._kv_scale - # in a native float32 value after weight loading. + # The k/v_scale will then be converted back to + # self._kv_scale in a native float32 value after weight loading self.quant_method = quant_method self.quant_method.create_weights(self) @@ -98,7 +99,8 @@ def forward( value, kv_cache, attn_metadata, - self._kv_scale, + self._k_scale, + self._v_scale, attn_type=attn_type) def extra_repr(self) -> str: diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 5a5317b65004e..81d308c4d4e22 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -45,7 +45,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -64,7 +65,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + k_scale: float, + v_scale: float, *args, ) -> torch.Tensor: output = torch.empty_like(query) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index a214f40d16514..ce7b4d129779c 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -66,7 +66,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -75,7 +76,8 @@ def write_to_paged_cache( value_cache, slot_mapping.flatten(), kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) @staticmethod @@ -90,7 +92,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -135,7 +138,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -172,7 +176,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - 
kv_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bc07d2b831862..684e1abf7bcf7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -196,6 +196,15 @@ def __init__(self, else: self.register_parameter("bias", None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5c916c9b4d7e4..cfef914ed6cf7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -407,31 +407,56 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. + """Create "weight" (aka k_scale and v_scale) for an attention layer. Args: layer: The layer that is using the QuantizeMethodBase factory. """ - # Initialize the KV cache scale to 1.0 as the default value. - # If the kv_scale appears in the checkpoint, it will be + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be # overwritten when loading weights. - layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") def process_weights_after_loading(self, layer: Module) -> None: - # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. 
if layer.kv_cache_dtype != "auto": - kv_scale = layer.kv_scale.to("cpu").tolist() - if not isinstance(kv_scale, float): + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = Parameter(torch.tensor(1.0), requires_grad=False) + v_scale = Parameter(torch.tensor(1.0), requires_grad=False) + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): raise ValueError("Only support per-tensor scaling factor " "for fp8 KV cache") - layer._kv_scale = kv_scale - if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + + # These are used in the final Attention.forward() + layer._k_scale = k_scale + layer._v_scale = v_scale + if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " - "cause accuracy issues. Please make sure kv-cache scaling " - "factor is available in the fp8 checkpoint.") - del layer.kv_scale + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c8568b3dc6690..cb83f43a2a4e2 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" - # If the weight on disk does not have a shape, give it one - # (such scales for AutoFp8). - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) @@ -462,3 +458,55 @@ def initialize_dummy_weights( param.data.copy_(tmp_param) else: param.uniform_(low, high) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. 
+ + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + print_warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded.") + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded.") + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a777d1fbfa802..f03e34b9e7c92 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,10 +44,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader) + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers @@ -460,18 +460,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). 
kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue if is_pp_missing_parameter(name, self): continue diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0c456ada61230..e739df87cf96a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -42,10 +42,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -415,19 +415,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e9ae2192f280d..e9aa4416eded4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -43,10 +43,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -382,18 +382,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader)
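All three model loaders touched above (llama, mixtral, qwen2) now delegate checkpoint-name handling to the new `maybe_remap_kv_scale_name` helper. The usage sketch below illustrates the remapping rules; the keys in `params_dict` are hypothetical examples following the `self_attn.attn` layout used by these models, and the snippet assumes a vLLM install that includes this change.

```python
from vllm.model_executor.model_loader.weight_utils import (
    maybe_remap_kv_scale_name)

# params_dict would normally come from dict(model.named_parameters());
# the keys below are hypothetical but follow the Attention layer layout.
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

# New-style checkpoints: ".k_scale"/".v_scale" gain the ".attn" prefix.
assert maybe_remap_kv_scale_name(
    "model.layers.0.self_attn.k_scale",
    params_dict) == "model.layers.0.self_attn.attn.k_scale"

# Deprecated checkpoints: a single ".kv_scale" is remapped to
# ".attn.k_scale" (a deprecation warning is printed once).
assert maybe_remap_kv_scale_name(
    "model.layers.0.self_attn.kv_scale",
    params_dict) == "model.layers.0.self_attn.attn.k_scale"

# Names without a scale suffix pass through unchanged.
assert maybe_remap_kv_scale_name(
    "model.layers.0.mlp.gate_proj.weight",
    params_dict) == "model.layers.0.mlp.gate_proj.weight"
```

Note that a deprecated single-tensor `kv_scale` only populates `k_scale` at load time; the duplication into `v_scale` happens afterwards, in `Fp8KVCacheMethod.process_weights_after_loading`, sketched next.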
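The post-load behavior of `Fp8KVCacheMethod.process_weights_after_loading` (the branch taken when `kv_cache_dtype != "auto"`) reduces to three cases. The function below is a standalone restatement of that decision, using plain floats instead of `torch.nn.Parameter` and assuming the `-1.0` sentinel set in `create_weights`; it is a summary sketch, not the implementation itself.

```python
def resolve_kv_scales(k_scale: float, v_scale: float) -> tuple:
    """Mirror of the post-load scale resolution in Fp8KVCacheMethod.

    Both scales start at the -1.0 sentinel from create_weights() and are
    overwritten only if the checkpoint provides them.
    """
    if k_scale > 0.0 and v_scale > 0.0:
        # Checkpoint shipped separate k_scale and v_scale tensors.
        return k_scale, v_scale
    if k_scale < 0.0 and v_scale < 0.0:
        # Nothing loaded: fall back to 1.0 (the real method additionally
        # warns when an fp8_e4m3 cache ends up with 1.0 scales).
        return 1.0, 1.0
    # Legacy single kv_scale: the loader remapped it to k_scale, so
    # duplicate whichever scale was actually loaded.
    loaded = max(k_scale, v_scale)
    return loaded, loaded


assert resolve_kv_scales(0.02, 0.04) == (0.02, 0.04)   # separate scales
assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)     # nothing loaded
assert resolve_kv_scales(0.03, -1.0) == (0.03, 0.03)   # legacy kv_scale
```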