diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
index 536b844417fab2..7fc1f10b9033b5 100644
--- a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
@@ -87,6 +87,7 @@ struct PagedAttentionManager {
     bool kv_cache_compression;
     ov::internal::CacheQuantMode key_cache_quant_mode;
     bool has_score_aggregation;
+    bool has_xattention;
     CacheRotationDescriptor rotation_config;
 
     std::vector<SubsequenceDescriptor> subsequence_descs;
@@ -134,6 +135,7 @@ struct PagedAttentionManager {
                           bool kv_cache_compression,
                           ov::internal::CacheQuantMode key_cache_quant_mode,
                           bool has_score_aggregation,
+                          bool has_xattention,
                           CacheRotationDescriptor rotation_config,
                           std::vector<float> threshold)
         : num_heads(num_heads)
@@ -145,6 +147,7 @@ struct PagedAttentionManager {
         , kv_cache_compression(kv_cache_compression)
        , key_cache_quant_mode(key_cache_quant_mode)
         , has_score_aggregation(has_score_aggregation)
+        , has_xattention(has_xattention)
         , rotation_config(rotation_config)
         , subsequence_descs(subsequence_descs)
         , test_engine(engine)
@@ -232,55 +235,63 @@ struct PagedAttentionManager {
     }
 
     memory::ptr get_key_cache_memory_cm() {
-        auto key_cache_dt = data_types::f16;
-        auto adjusted_head_size = k_head_size;
-        if (kv_cache_compression) {
-            key_cache_dt = data_types::i8;
-            adjusted_head_size += 4;
-        }
-
-        auto num_blocks = block_indices.back() + 1;
-        auto key_cache_shape = ov::PartialShape{num_blocks, num_kv_heads, block_size, adjusted_head_size};
+        auto key_cache_dt = kv_cache_compression ? data_types::i8 : data_types::f16;
+        const int head_size = k_head_size;
+        const int adjusted_head_size = head_size + (kv_cache_compression ? 4 : 0);
+
+        const auto num_blocks = block_indices.back() + 1;
+        auto key_cache_shape = ov::PartialShape{static_cast<int64_t>(num_blocks),
+                                                static_cast<int64_t>(num_kv_heads),
+                                                static_cast<int64_t>(block_size),
+                                                static_cast<int64_t>(adjusted_head_size)};
         auto key_cache_layout = layout{key_cache_shape, key_cache_dt, format::bfyx};
         auto memory = test_engine.allocate_memory(key_cache_layout);
 
         for (int i = 0; i < static_cast<int>(subsequence_descs.size()); i++) {
-            int past_len = subsequence_descs[i].past_len;
-            if (past_len != 0) {
-                int blocks_num = ceil_div(past_len + 1, block_size);
-                int start_block_idx = block_indices[block_indices_begins[i]];
-                for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
-                    int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx) : block_size;
-                    for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
-                        for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) {
-                            size_t input_token_offset = block_idx * block_size + token_idx;
-                            ov::float16* data_ptr = key_data[i].data() + input_token_offset * num_kv_heads * v_head_size + head_idx * v_head_size;
-                            if (kv_cache_compression) {
-                                auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size);
-                                auto quantized_data_ptr = quantized_data.data();
-
-                                // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size]
-                                size_t output_block_offset =
-                                    (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size + head_idx * block_size * adjusted_head_size;
-                                size_t output_offset = output_block_offset + token_idx * v_head_size;
-                                set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset);
-
-                                size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2;
-                                set_values(test_stream, memory, &scale, 1, comp_offset + token_idx);
-                                set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx);
-                            } else {
-                                // shape: [num_blocks, num_kv_heads, block_size, v_head_size]
-                                size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size +
-                                                       head_idx * block_size * v_head_size + token_idx * v_head_size;
+            const int past_len = subsequence_descs[i].past_len;
+            if (past_len == 0)
+                continue;
+
+            const int blocks_num = ceil_div(past_len + 1, block_size);
+            const int start_block_idx = block_indices[block_indices_begins[i]];
+
+            for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
+                const int last_token_idx = (block_idx == blocks_num - 1) ? (past_len - block_size * block_idx) : block_size;
+
+                for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
+                    for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) {
+                        const size_t input_token_offset = static_cast<size_t>(block_idx) * block_size + token_idx;
+                        ov::float16* src_ptr =
+                            key_data[i].data() + input_token_offset * static_cast<size_t>(num_kv_heads) * head_size + static_cast<size_t>(head_idx) * head_size;
+
+                        if (!kv_cache_compression) {
+                            const size_t base = (static_cast<size_t>(start_block_idx + block_idx) * num_kv_heads * block_size * head_size) +
+                                                (static_cast<size_t>(head_idx) * block_size * head_size);
+                            const size_t off = base + static_cast<size_t>(token_idx) * head_size;
+                            set_values(test_stream, memory, src_ptr, head_size, off);
+                        } else {
+                            auto [qdata, scale, zp] = quantize_data(src_ptr, head_size, false, true);
+                            int8_t* qptr = reinterpret_cast<int8_t*>(qdata.data());
 
-                                set_values(test_stream, memory, data_ptr, v_head_size, output_offset);
-                            }
+                            const size_t block_stride_i8 = static_cast<size_t>(adjusted_head_size) * block_size;
+                            const size_t block_base_i8 = (static_cast<size_t>(start_block_idx + block_idx) * num_kv_heads + head_idx) * block_stride_i8;
+
+                            const size_t data_off_i8 = block_base_i8 + token_idx * head_size;
+                            set_values(test_stream, memory, qptr, head_size, data_off_i8);
+
+                            const size_t scale_base_i8 = block_base_i8 + head_size * block_size;
+                            const size_t zp_base_i8 = scale_base_i8 + block_size * sizeof(ov::float16);
+
+                            const size_t scale_off_f16 = scale_base_i8 / 2 + token_idx;
+                            const size_t zp_off_f16 = zp_base_i8 / 2 + token_idx;
+
+                            set_values(test_stream, memory, &scale, 1, scale_off_f16);
+                            set_values(test_stream, memory, &zp, 1, zp_off_f16);
                         }
                     }
                }
            }
        }
-
         return memory;
     }
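Note on the offset arithmetic above (the value-cache hunk below uses the same scheme): with compression enabled, each (block, head) slot stores head_size * block_size int8 values, followed by block_size f16 scales and block_size f16 zero-points; that trailing metadata is where the extra 4 bytes per token in adjusted_head_size come from (2 for the scale, 2 for the zero-point). The divisions by 2 convert byte offsets into f16 element offsets, since the set_values<ov::float16> helper counts its destination offset in elements. A minimal standalone sketch of that math; the names CacheOffsets and compressed_offsets are illustrative, not part of the test file:

    #include <cstddef>
    #include <cstdint>

    // Per-(block, head) slot layout assumed by the writes above:
    //   [head_size * block_size] int8 data | [block_size] f16 scales | [block_size] f16 zero-points
    struct CacheOffsets {
        size_t data_i8;    // first int8 element of this token's row
        size_t scale_f16;  // this token's scale, counted in f16 elements
        size_t zp_f16;     // this token's zero-point, counted in f16 elements
    };

    inline CacheOffsets compressed_offsets(int block, int head, int token,
                                           int num_kv_heads, int block_size, int head_size) {
        const size_t block_stride_i8 = static_cast<size_t>(head_size + 4) * block_size;
        const size_t block_base_i8 = (static_cast<size_t>(block) * num_kv_heads + head) * block_stride_i8;
        const size_t scale_base_i8 = block_base_i8 + static_cast<size_t>(head_size) * block_size;
        const size_t zp_base_i8 = scale_base_i8 + static_cast<size_t>(block_size) * sizeof(uint16_t);  // f16 == 2 bytes
        return {block_base_i8 + static_cast<size_t>(token) * head_size,
                scale_base_i8 / 2 + token,  // byte offset -> f16 element index
                zp_base_i8 / 2 + token};
    }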
@@ -383,60 +394,68 @@ struct PagedAttentionManager {
     }
 
     memory::ptr get_value_cache_memory() {
-        auto value_cache_dt = data_types::f16;
-        auto adjusted_head_size = v_head_size;
-        if (kv_cache_compression) {
-            value_cache_dt = data_types::i8;
-            adjusted_head_size += 4;
-        }
+        auto value_cache_dt = kv_cache_compression ? data_types::i8 : data_types::f16;
+        const int head_size = v_head_size;
 
-        auto num_blocks = block_indices.back() + 1;
-        auto value_cache_shape = ov::PartialShape{ num_blocks, num_kv_heads, block_size, adjusted_head_size };
-        auto value_cache_layout = layout{ value_cache_shape, value_cache_dt, format::bfyx };
+        const int adjusted_head_size = head_size + (kv_cache_compression ? 4 : 0);
+
+        const auto num_blocks = block_indices.back() + 1;
+        auto value_cache_shape = ov::PartialShape{static_cast<int64_t>(num_blocks),
+                                                  static_cast<int64_t>(num_kv_heads),
+                                                  static_cast<int64_t>(block_size),
+                                                  static_cast<int64_t>(adjusted_head_size)};
+        auto value_cache_layout = layout{value_cache_shape, value_cache_dt, format::bfyx};
         auto memory = test_engine.allocate_memory(value_cache_layout);
 
         for (int i = 0; i < static_cast<int>(subsequence_descs.size()); i++) {
-            int past_len = subsequence_descs[i].past_len;
-            if (past_len != 0) {
-                int blocks_num = ceil_div(past_len + 1, block_size);
-                int start_block_idx = block_indices[block_indices_begins[i]];
-                for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
-                    int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_size * block_idx)
-                                                                     : block_size;
-                    for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
-                        for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) {
-                            size_t input_token_offset = block_idx * block_size + token_idx;
-                            ov::float16* data_ptr = value_data[i].data() +
-                                                    input_token_offset * num_kv_heads * v_head_size +
-                                                    head_idx * v_head_size;
-                            if (kv_cache_compression) {
-                                auto [quantized_data, scale, zp] = quantize_data(data_ptr, v_head_size);
-                                auto quantized_data_ptr = quantized_data.data();
+            const int past_len = subsequence_descs[i].past_len;
+            if (past_len == 0)
+                continue;
 
-                                // shape: [num_blocks, num_kv_heads, block_size, adjusted_head_size]
-                                size_t output_block_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * adjusted_head_size +
-                                                             head_idx * block_size * adjusted_head_size;
-                                size_t output_offset = output_block_offset +
-                                                       token_idx * v_head_size;
-                                set_values(test_stream, memory, quantized_data_ptr, v_head_size, output_offset);
+            const int blocks_num = ceil_div(past_len + 1, block_size);
+            const int start_block_idx = block_indices[block_indices_begins[i]];
 
-                                size_t comp_offset = (output_block_offset + v_head_size * block_size) / 2;
-                                set_values(test_stream, memory, &scale, 1, comp_offset + token_idx);
-                                set_values(test_stream, memory, &zp, 1, comp_offset + block_size + token_idx);
-                            } else {
-                                // shape: [num_blocks, num_kv_heads, block_size, v_head_size]
-                                size_t output_offset = (start_block_idx + block_idx) * num_kv_heads * block_size * v_head_size +
-                                                       head_idx * block_size * v_head_size +
-                                                       token_idx * v_head_size;
+            for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
+                const int last_token_idx = (block_idx == blocks_num - 1) ? (past_len - block_size * block_idx) : block_size;
-                                set_values(test_stream, memory, data_ptr, v_head_size, output_offset);
-                            }
+                for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
+                    for (int head_idx = 0; head_idx < num_kv_heads; head_idx++) {
+                        const size_t input_token_offset = static_cast<size_t>(block_idx) * block_size + token_idx;
+
+                        ov::float16* src_ptr = value_data[i].data() + input_token_offset * static_cast<size_t>(num_kv_heads) * head_size +
+                                               static_cast<size_t>(head_idx) * head_size;
+
+                        if (!kv_cache_compression) {
+                            const size_t base = (static_cast<size_t>(start_block_idx + block_idx) * static_cast<size_t>(num_kv_heads) *
+                                                 static_cast<size_t>(block_size) * static_cast<size_t>(head_size)) +
+                                                (static_cast<size_t>(head_idx) * static_cast<size_t>(block_size) * static_cast<size_t>(head_size));
+                            const size_t off = base + static_cast<size_t>(token_idx) * static_cast<size_t>(head_size);
+                            set_values(test_stream, memory, src_ptr, head_size, off);
+                        } else {
+                            auto [qdata, scale, zp] = quantize_data(src_ptr, head_size, false, has_xattention);
+                            int8_t* qptr = reinterpret_cast<int8_t*>(qdata.data());
+
+                            const size_t block_stride_i8 = static_cast<size_t>(adjusted_head_size) * static_cast<size_t>(block_size);
+                            const size_t block_base_i8 =
+                                (static_cast<size_t>(start_block_idx + block_idx) * static_cast<size_t>(num_kv_heads) + static_cast<size_t>(head_idx)) *
+                                block_stride_i8;
+
+                            const size_t data_off_i8 = block_base_i8 + static_cast<size_t>(token_idx) * static_cast<size_t>(head_size);
+                            set_values(test_stream, memory, qptr, head_size, data_off_i8);
+
+                            const size_t scale_base_i8 = block_base_i8 + static_cast<size_t>(head_size) * static_cast<size_t>(block_size);
+                            const size_t zp_base_i8 = scale_base_i8 + static_cast<size_t>(block_size) * sizeof(ov::float16);
+
+                            const size_t scale_off_f16 = (scale_base_i8 >> 1) + static_cast<size_t>(token_idx);
+                            const size_t zp_off_f16 = (zp_base_i8 >> 1) + static_cast<size_t>(token_idx);
+
+                            set_values(test_stream, memory, &scale, 1, scale_off_f16);
+                            set_values(test_stream, memory, &zp, 1, zp_off_f16);
                        }
                    }
                }
            }
        }
-
         return memory;
     }
@@ -648,47 +667,76 @@ struct PagedAttentionManager {
         return data;
     }
 
-    static std::tuple<std::vector<char>, ov::float16, ov::float16> quantize_data(ov::float16* data, size_t size, bool expand_range = false) {
+    static std::tuple<std::vector<char>, ov::float16, ov::float16> quantize_data(ov::float16* data,
+                                                                                 size_t size,
+                                                                                 bool expand_range = false,
+                                                                                 bool has_xattention = false) {
         float min_value = std::numeric_limits<float>::max();
         float max_value = std::numeric_limits<float>::lowest();
         for (size_t i = 0; i < size; i++) {
-            min_value = std::min((float)(data[i]), min_value);
-            max_value = std::max((float)(data[i]), max_value);
+            float v = static_cast<float>(data[i]);
+            min_value = std::min(min_value, v);
+            max_value = std::max(max_value, v);
         }
 
-        float diff_value = 0.001;
+        if (has_xattention) {
+            if (max_value == min_value) {
+                std::vector<char> qdata(size, 0);
+                return {qdata, ov::float16(0.0f), ov::float16(min_value)};
+            }
+
+            float diff_value = max_value - min_value;
+            if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) {
+                diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f);
+            }
+
+            float scale_val = 255.0f / diff_value;
+            float zp_val = -min_value * scale_val;
+
+            std::vector<char> qdata(size);
+            for (size_t i = 0; i < size; i++) {
+                float q = data[i] * scale_val + zp_val;
+                int v = static_cast<int>(std::nearbyint(q));
+                if (v < 0)
+                    v = 0;
+                if (v > 255)
+                    v = 255;
+                qdata[i] = static_cast<char>(v);
+            }
+
+            ov::float16 scale = static_cast<ov::float16>(diff_value / 255.0f);
+            ov::float16 zp = static_cast<ov::float16>(zp_val);
+            return {qdata, scale, zp};
+        }
+
+        float diff_value = 0.001f;
         if (max_value != min_value)
             diff_value = max_value - min_value;
         if (expand_range && std::abs(diff_value) <= std::abs(max_value) * 0.1f) {
-            // compensate too small range
             diff_value = (max_value - min_value) + std::max(1.0f, max_value * 0.1f);
         }
 
-        float scale = (std::numeric_limits<char>::max() - std::numeric_limits<char>::lowest()) / diff_value;
-        float zp = ((float)-min_value * scale) + std::numeric_limits<char>::lowest();
-        std::vector<char> quantized_data;
-        quantized_data.resize(size);
+        float scale = (std::numeric_limits<char>::max() - std::numeric_limits<char>::lowest()) / diff_value;
+        float zp = -min_value * scale + std::numeric_limits<char>::lowest();
+        std::vector<char> qdata(size);
 
         auto convert_char_rte = [](float val) {
             float rounded = std::nearbyint(val);
-
-            if (rounded > 127.0f) {
+            if (rounded > 127.0f)
                 return static_cast<char>(127);
-            } else if (rounded < -128.0f) {
+            if (rounded < -128.0f)
                 return static_cast<char>(-128);
-            } else {
-                return static_cast<char>(rounded);
-            }
+            return static_cast<char>(rounded);
         };
 
         for (size_t i = 0; i < size; i++) {
-            quantized_data[i] = convert_char_rte(data[i] * scale + zp);
+            qdata[i] = convert_char_rte(data[i] * scale + zp);
         }
 
-        scale = 1.0f / scale;
-
-        return std::make_tuple(quantized_data, scale, zp);
+        ov::float16 scale_out = static_cast<ov::float16>(1.0f / scale);
+        ov::float16 zp_out = static_cast<ov::float16>(zp);
+        return {qdata, scale_out, zp_out};
     }
 };
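The has_xattention branch above replaces the signed i8 scheme with an asymmetric u8 one: q = clamp(round(x * 255 / diff + zp), 0, 255), and the returned scale is diff / 255 so a consumer can dequantize as (q - zp) * scale. A small self-contained round-trip check of that scheme (a sketch under those assumptions, not code from the test file):

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
        const std::vector<float> x = {-1.5f, -0.25f, 0.0f, 0.75f, 2.0f};
        const float min_v = *std::min_element(x.begin(), x.end());
        const float max_v = *std::max_element(x.begin(), x.end());
        const float diff = max_v - min_v;

        const float scale_val = 255.0f / diff;    // quantization scale
        const float zp_val = -min_v * scale_val;  // zero-point in the u8 domain
        const float scale = diff / 255.0f;        // dequantization scale (what quantize_data returns)

        for (float v : x) {
            const int q = std::clamp(static_cast<int>(std::nearbyint(v * scale_val + zp_val)), 0, 255);
            const float back = (static_cast<float>(q) - zp_val) * scale;
            assert(std::fabs(back - v) <= scale);  // round-trips within one quantization step
        }
        return 0;
    }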
@@ -1168,6 +1216,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam<paged_attention_test_params> {
                                       p.kv_cache_compression,
                                       p.key_cache_quant_mode,
                                       p.scores_mode == ScoresMode::SNAPKV,
+                                      p.has_xattention,
                                       p.rotation_config,
                                       p.threshold);
 
@@ -1652,4 +1701,16 @@ INSTANTIATE_TEST_SUITE_P(smoke_cm_xattention, xattention_test, ::testing::Values
     paged_attention_test_params{ {{1, 128}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },    // 2nd token
     paged_attention_test_params{ {{1, 129}}, 2, 2, 64, 64, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },    // 2nd token
     paged_attention_test_params{ {{1, 32}}, 28, 28, 128, 128, 256, {0.9}, 0, true, DISABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },  // 2nd token
+
+    paged_attention_test_params{ {{32, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },    // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },  // 1st token
+    paged_attention_test_params{ {{2048, 0}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },  // 1st token
+    paged_attention_test_params{ {{32, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },    // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },  // 1st token
+    paged_attention_test_params{ {{2048, 0}}, 4, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, ENABLE_FA_V2 },  // 1st token
+
+    paged_attention_test_params{ {{1, 31}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },    // 2nd token
+    paged_attention_test_params{ {{1, 32}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },    // 2nd token
+    paged_attention_test_params{ {{1, 1023}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },  // 2nd token
+    paged_attention_test_params{ {{1, 1024}}, 2, 2, 64, 64, 256, {0.9}, 0, true, ENABLE_CACHE_COMPRESSION, ov::internal::CacheQuantMode::BY_TOKEN, STATIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION, DISABLE_FA_V2 },  // 2nd token
 }));
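The added rows split into first-token cases ({{n, 0}}: prefill against an empty cache) and second-token cases ({{1, n}}: one query token against n cached, compressed tokens), per the trailing comments. The 1023/1024 pair appears to probe a block boundary: with the block math PagedAttentionManager uses above, the cache spans ceil_div(past_len + 1, block_size) blocks, so those two rows land on opposite sides of a 256-token block edge. A quick check of that arithmetic (standalone sketch, not test code):

    #include <cstdio>

    // Same ceil_div convention the test manager uses for sizing the block table.
    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        const int block_size = 256;  // matches the block_size column in the rows above
        std::printf("{1, 1023}: %d blocks\n", ceil_div(1023 + 1, block_size));  // 4: exactly fills four blocks
        std::printf("{1, 1024}: %d blocks\n", ceil_div(1024 + 1, block_size));  // 5: spills into a fifth block
        return 0;
    }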