diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp
index 7db9c2c0d59419..03e4af4d16359b 100644
--- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp
+++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp
@@ -943,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) {
             auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) {
                 auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx);
-                const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis();
 
                 padding::DynamicDimsMask info_dynamic_pad_scales;
                 info_dynamic_pad_scales[scales_zp_concat_axis] = 1;
                 scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales;
diff --git a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
index e95e2e94ff4ab0..945894af30170c 100644
--- a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
+++ b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
@@ -62,8 +62,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache
         return sequence_axis >= 0 ? sequence_axis : past_layout_rank + sequence_axis;
     }
 
-    static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) {
-        const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis];
+    static int64_t get_scale_zp_sequence_axis() {
+        const auto scale_zp_concat_axis = 2;
         return scale_zp_concat_axis;
     }
 
diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
index abfeabe2b6a149..e574684e6b4f10 100644
--- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
+++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -851,7 +851,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                 auto prealloc_shape = updated_layouts[i].get_shape();
                 const auto shape_rank = prealloc_shape.size();
                 const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                             : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                             : kv_cache_inst::get_scale_zp_sequence_axis();
 
                 prealloc_shape[seq_axis] += tmp_prealloc_count;
                 required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
@@ -883,7 +883,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                 const auto& desc = _node->as<kv_cache>().get_primitive();
                 const auto shape_rank = updated_layouts[i].get_shape().size();
                 const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                             : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                             : kv_cache_inst::get_scale_zp_sequence_axis();
 
                 prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
             } else {
@@ -907,7 +907,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
             auto& present_layout = _impl_params->output_layouts[i];
             const auto present_layout_rank = present_layout.get_partial_shape().size();
             const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank)
-                                              : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                              : kv_cache_inst::get_scale_zp_sequence_axis();
 
             auto max_pad = kv_cache_inst::get_max_pad(present_layout,
                                                       _max_output_layout_count[i],
@@ -978,7 +978,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
             if (max_pad > 0) {
                 if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
                     auto present_scales_layout = _impl_params->output_layouts[2];
-                    const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                    const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
 
                     // In case of compressed KV-cache, calling update_impl for each iteration
                     // because of scales layout [batch, num_heads, seq_len, head_size], which requires proper
@@ -1374,7 +1374,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
         if (desc->compressed) {
            auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable);
            auto& present_scales_layout = _impl_params->output_layouts[2];
-           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
            kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis);
 
            GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl;
@@ -1398,7 +1398,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
 
         if (desc->compressed) {
             auto& past_scale_layout = _impl_params->input_layouts[3];
-            const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+            const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
             kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis);
 
             if (desc->get_compression_zp_inputs_num() > 0) {
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
index 908732dc357222..6721d0f9ebd608 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
@@ -191,8 +191,7 @@ std::vector<ov::PartialShape> shape_infer(const KVCacheCompressed* op,
         auto quantized_data_shapes =
             ov::op::internal::DynamicQuantize::shape_infer(&dq_op, { input_shapes[1] });
 
-        const auto concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
-        const auto scales_concat_axis = op->get_quantization_attrs().scales_zp_output_order[concat_axis];
+        const auto scales_concat_axis = 2;
         ov::PartialShape compression_scale_shape = input_shapes[3];
         compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis];
         out_shapes[2] = compression_scale_shape;
diff --git a/src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp b/src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp
index c774049fe0690f..65221107967bcc 100644
--- a/src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp
+++ b/src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp
@@ -120,8 +120,8 @@ inline std::shared_ptr<ov::Node> make_qkv_transpose(ov::Output<ov::Node> qkv, st
     return std::make_shared<ov::op::v1::Transpose>(qkv, transpose_const);
 }
 
-inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx) {
-    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
+inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx, int axis_val = 0) {
+    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, axis_val);
     return std::make_shared<ov::op::v8::Gather>(kv_past, beam_idx, axis, 0);
 }
 
@@ -242,8 +242,8 @@ inline std::shared_ptr<ov::Model> make_llm_kv_cache_sdpa_pattern(ov::Dimension b
         in_beam_idx->set_friendly_name("beam_idx");
         params.push_back(in_beam_idx);
 
-        concat_k_input = make_kv_rearrange(past_k, in_beam_idx);
-        concat_v_input = make_kv_rearrange(past_v, in_beam_idx);
+        concat_k_input = make_kv_rearrange(past_k, in_beam_idx, qkv_order[0]);
+        concat_v_input = make_kv_rearrange(past_v, in_beam_idx, qkv_order[0]);
     }
 
     auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{concat_k_input, in_k_token}, concat_axis);
diff --git a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
index 89612039fb788f..7bb4a7385bcdc4 100644
--- a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
+++ b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
@@ -342,6 +342,7 @@ std::vector<Params> get_test_params() {
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
     p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+    p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});
 
     // Beam search
     p.push_back({with_rearrange, !with_mask, !with_scale, !causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
@@ -351,6 +352,7 @@ std::vector<Params> get_test_params() {
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});
 
     /* -- causal mask -- */
 
@@ -367,6 +369,8 @@ std::vector<Params> get_test_params() {
     p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});
+
 
     return p;
 }
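
Note on the core change, illustrated with a minimal standalone sketch. QuantizationAttributesStub is a hypothetical stand-in for kv_cache::QuantizationAttributes (only the one field used here is modeled), and the sample orders are illustrative. Previously the scales/zp concat axis was derived by indexing scales_zp_output_order with the KV-cache concat axis; since the compressed cache emits scales/zp in [batch, num_heads, seq_len, head_size] order (per the comment retained in the diff), the sequence axis can be fixed at 2 regardless of the data tensor's qkv_order:

#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stub of kv_cache::QuantizationAttributes; only the field used
// by the old derivation is modeled.
struct QuantizationAttributesStub {
    std::array<int64_t, 4> scales_zp_output_order;
};

// Old derivation: map the data tensor's concat axis through the output order.
int64_t old_scale_zp_sequence_axis(int64_t concat_axis, const QuantizationAttributesStub& attrs) {
    return attrs.scales_zp_output_order[concat_axis];
}

// New derivation: scales/zp always use [batch, num_heads, seq_len, head_size],
// so their sequence (concat) axis is unconditionally 2.
int64_t new_scale_zp_sequence_axis() {
    return 2;
}

int main() {
    const QuantizationAttributesStub identity_order{{0, 1, 2, 3}};
    const QuantizationAttributesStub transposed_order{{1, 2, 0, 3}};  // sample values
    // For the identity order the two derivations coincide...
    assert(old_scale_zp_sequence_axis(2, identity_order) == new_scale_zp_sequence_axis());
    // ...but for a transposed order they diverge, which is the class of
    // configurations the added {1, 2, 0, 3} test cases exercise.
    assert(old_scale_zp_sequence_axis(2, transposed_order) != new_scale_zp_sequence_axis());
    return 0;
}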
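Similarly, a sketch of the shape bookkeeping in the kv_cache.cpp shape_infer hunk, assuming OpenVINO's ov::PartialShape; update_scales_shape is a hypothetical free-standing rewrite, not the plugin function:

#include "openvino/core/partial_shape.hpp"

// Hypothetical rewrite of the scales-shape update: the present scales shape is
// the past scales shape (input_shapes[3]) grown along the fixed sequence
// axis 2 by the newly quantized tokens' scales extent.
ov::PartialShape update_scales_shape(ov::PartialShape past_scales,
                                     const ov::PartialShape& new_token_scales) {
    const auto scales_concat_axis = 2;
    past_scales[scales_concat_axis] += new_token_scales[scales_concat_axis];
    return past_scales;
}

For example, past scales of [B, H, 10, 1] combined with new-token scales of [B, H, 1, 1] yield [B, H, 11, 1], independent of how the data tensor itself is transposed.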