Skip to content

Commit

Permalink
[Kernel][Attention] Separate Attention.kv_scale into k_scale and …
Browse files Browse the repository at this point in the history
…`v_scale` (#6081)
  • Loading branch information
mgoin authored Jul 16, 2024
1 parent 160e1d8 commit 978aed5
Show file tree
Hide file tree
Showing 33 changed files with 317 additions and 185 deletions.
8 changes: 5 additions & 3 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
start_time = time.perf_counter()

# Using default kv_scale
kv_scale = 1.0
k_scale = v_scale = 1.0

for _ in range(num_iters):
if version == "v1":
Expand All @@ -117,7 +117,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
)
elif version == "v2":
ops.paged_attention_v2(
Expand All @@ -136,7 +137,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
k_scale,
v_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
Expand Down
53 changes: 27 additions & 26 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
Expand Down Expand Up @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
k_vec_quant, k_scale);
}
}

Expand Down Expand Up @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
v_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
Expand Down Expand Up @@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
Expand Down Expand Up @@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
Expand Down Expand Up @@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

Expand All @@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
Expand Down Expand Up @@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

Expand Down Expand Up @@ -815,8 +815,8 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
Expand All @@ -833,7 +833,7 @@ void paged_attention_v1(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
Expand All @@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
Expand Down Expand Up @@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
Expand Down Expand Up @@ -980,8 +981,8 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
Expand Down
4 changes: 2 additions & 2 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double kv_scale);
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
Expand Down
21 changes: 11 additions & 10 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float kv_scale) {
const int head_size, const int block_size, const int x, const float k_scale,
const float v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
Expand Down Expand Up @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
}
}
}
Expand Down Expand Up @@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale);
num_heads, head_size, block_size, x, k_scale, v_scale);

void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
Expand All @@ -258,7 +258,8 @@ void reshape_and_cache(
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const double kv_scale) {
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
Expand Down Expand Up @@ -318,13 +319,13 @@ namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const float kv_scale,
const float scale,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
}
}

Expand All @@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);

// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double kv_scale, const std::string& kv_cache_dtype) {
const double scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
Expand Down
12 changes: 6 additions & 6 deletions csrc/cpu/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,11 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
Expand Down Expand Up @@ -742,11 +742,11 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
Expand Down
5 changes: 3 additions & 2 deletions csrc/cpu/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
const std::string& kv_cache_dtype, double k_scale,
double v_scale) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);

int num_tokens = key.size(0);
int num_heads = key.size(1);
Expand Down
10 changes: 5 additions & 5 deletions csrc/cpu/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
Expand All @@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
Expand Down Expand Up @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale) -> ()");
" float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}

Expand Down
Loading

0 comments on commit 978aed5

Please sign in to comment.