[GPU] Fix compressed KV-cache shape infer for QKV order {1,2,0,3} (#28592)

### Details:
- This change restores the original logic of the scale/zp axis calculation
and fixes the shape inference of the KV-Cache operation. Previously, the
scale/zp sequence axis had to be adjusted with concat_axis, but since the
introduction of independent macros for the key and value scale/zp offset
calculation in the micro_sdpa kernel, that adjustment is no longer needed
and now causes incorrect indexing. This change therefore reverts to the
original fixed scale/zp axis of 2 (see the sketch below).
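
A minimal sketch of the before/after axis logic, assuming the scales/zp tensor layout [batch, num_heads, seq_len, head_size] mentioned in the diff below; the function names and order values here are illustrative, not the plugin's actual code:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Old behavior (illustrative): the scale/zp sequence axis was looked up
// through the QKV output order, so it varied with the transpose order.
int64_t old_scale_zp_sequence_axis(int64_t sequence_axis,
                                   const std::array<int64_t, 4>& scales_zp_output_order) {
    return scales_zp_output_order[sequence_axis];
}

// New behavior (as in the diff): the scales/zp tensor always keeps the
// layout [batch, num_heads, seq_len, head_size], so the axis is fixed.
int64_t new_scale_zp_sequence_axis() {
    return 2;
}

int main() {
    // Default order {0, 1, 2, 3}: the old lookup happens to return 2 as
    // well, which is why the incorrect indexing stayed hidden.
    std::cout << old_scale_zp_sequence_axis(2, {0, 1, 2, 3}) << "\n";  // 2
    // Order {1, 2, 0, 3}: the lookup yields axis 0, so the scales tensor
    // would be indexed along batch instead of seq_len.
    std::cout << old_scale_zp_sequence_axis(2, {1, 2, 0, 3}) << "\n";  // 0
    std::cout << new_scale_zp_sequence_axis() << "\n";                 // 2
    return 0;
}
```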
sshlyapn authored Jan 22, 2025
1 parent a6ab17e commit b8b2435
Showing 6 changed files with 18 additions and 15 deletions.
@@ -943,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) {
auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) {
auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx);

-        const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+        const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis();
padding::DynamicDimsMask info_dynamic_pad_scales;
info_dynamic_pad_scales[scales_zp_concat_axis] = 1;
scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales;
src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h (2 additions, 2 deletions)
@@ -62,8 +62,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache
return sequence_axis >= 0 ? sequence_axis : past_layout_rank + sequence_axis;
}

-    static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) {
-        const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis];
+    static int64_t get_scale_zp_sequence_axis() {
+        const auto scale_zp_concat_axis = 2;
return scale_zp_concat_axis;
}

src/plugins/intel_gpu/src/graph/primitive_inst.cpp (6 additions, 6 deletions)
@@ -851,7 +851,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
auto prealloc_shape = updated_layouts[i].get_shape();
const auto shape_rank = prealloc_shape.size();
const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-            : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+            : kv_cache_inst::get_scale_zp_sequence_axis();

prealloc_shape[seq_axis] += tmp_prealloc_count;
required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
@@ -883,7 +883,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
const auto& desc = _node->as<kv_cache>().get_primitive();
const auto shape_rank = updated_layouts[i].get_shape().size();
const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-            : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+            : kv_cache_inst::get_scale_zp_sequence_axis();

prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
} else {
@@ -907,7 +907,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
auto& present_layout = _impl_params->output_layouts[i];
const auto present_layout_rank = present_layout.get_partial_shape().size();
const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank)
-            : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+            : kv_cache_inst::get_scale_zp_sequence_axis();

auto max_pad = kv_cache_inst::get_max_pad(present_layout,
_max_output_layout_count[i],
@@ -978,7 +978,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
if (max_pad > 0) {
if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
auto present_scales_layout = _impl_params->output_layouts[2];
-            const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+            const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();

// In case of compressed KV-cache, calling update_impl for each iteration
// because of scales layout [batch, num_heads, seq_len, head_size], which requires proper
@@ -1374,7 +1374,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
if (desc->compressed) {
auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable);
auto& present_scales_layout = _impl_params->output_layouts[2];
-        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis);
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
<< " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl;
@@ -1398,7 +1398,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {

if (desc->compressed) {
auto& past_scale_layout = _impl_params->input_layouts[3];
-        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+        const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis);

if (desc->get_compression_zp_inputs_num() > 0) {
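
A hedged sketch of the in-place growth logic these hunks adjust: the cache buffer is preallocated beyond the current sequence length, and the unused tail is tracked as dynamic padding along the sequence axis. The types and `update_pad` helper below are simplified stand-ins for the plugin's layout machinery, not its actual API:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Simplified stand-in for a padded layout: shape holds the valid extents,
// pad_upper holds the preallocated-but-unused elements per axis.
struct PaddedShape {
    std::array<int64_t, 4> shape;      // [batch, num_heads, seq_len, head_size]
    std::array<int64_t, 4> pad_upper;  // unused preallocated elements per axis
};

// Illustrative equivalent of kv_cache_inst::update_pad: record how much
// room is left along one axis without reallocating the buffer.
void update_pad(PaddedShape& layout, int64_t pad, int64_t axis) {
    layout.pad_upper[axis] = pad;
}

int main() {
    // Scales tensor of a compressed KV-cache, preallocated for 16 tokens
    // with 10 valid; the sequence axis is the fixed axis 2.
    PaddedShape scales{{1, 4, 10, 1}, {0, 0, 6, 0}};
    const int64_t sequence_axis = 2;  // as returned by get_scale_zp_sequence_axis()

    // One new token arrives: max_pad = 6, new_seq_len = 1, so after the
    // in-place concat the remaining pad is max_pad - new_seq_len = 5.
    update_pad(scales, 6 - 1, sequence_axis);
    scales.shape[sequence_axis] += 1;

    std::cout << scales.shape[sequence_axis] << " valid tokens, "
              << scales.pad_upper[sequence_axis] << " padded\n";  // 11 valid, 5 padded
    return 0;
}
```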
@@ -191,8 +191,7 @@ std::vector<ov::PartialShape> shape_infer(const KVCacheCompressed* op,
auto quantized_data_shapes =
ov::op::internal::DynamicQuantize::shape_infer(&dq_op, { input_shapes[1] });

-    const auto concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
-    const auto scales_concat_axis = op->get_quantization_attrs().scales_zp_output_order[concat_axis];
+    const auto scales_concat_axis = 2;
ov::PartialShape compression_scale_shape = input_shapes[3];
compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis];
out_shapes[2] = compression_scale_shape;
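
A minimal worked example of the corrected accumulation above, with illustrative shapes (the dimension values are assumptions, not taken from the tests):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Illustrative shapes only: past compression scales and the scales of
    // the newly quantized token, both laid out [batch, num_heads, seq_len, 1].
    std::vector<int64_t> compression_scale_shape{1, 4, 10, 1};
    std::vector<int64_t> quantized_scales_shape{1, 4, 1, 1};

    // Fixed sequence axis, as in the corrected shape_infer above; the old
    // order-dependent lookup could pick the batch axis for order {1,2,0,3}.
    const int64_t scales_concat_axis = 2;
    compression_scale_shape[scales_concat_axis] += quantized_scales_shape[scales_concat_axis];

    std::cout << compression_scale_shape[scales_concat_axis] << "\n";  // 11
    return 0;
}
```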
src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp (4 additions, 4 deletions)
@@ -120,8 +120,8 @@ inline std::shared_ptr<ov::Node> make_qkv_transpose(ov::Output<ov::Node> qkv, st
return std::make_shared<ov::op::v1::Transpose>(qkv, transpose_const);
}

-inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx) {
-    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
+inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx, int axis_val = 0) {
+    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, axis_val);
return std::make_shared<ov::op::v8::Gather>(kv_past, beam_idx, axis, 0);
}

@@ -242,8 +242,8 @@ inline std::shared_ptr<ov::Model> make_llm_kv_cache_sdpa_pattern(ov::Dimension b
in_beam_idx->set_friendly_name("beam_idx");
params.push_back(in_beam_idx);

-        concat_k_input = make_kv_rearrange(past_k, in_beam_idx);
-        concat_v_input = make_kv_rearrange(past_v, in_beam_idx);
+        concat_k_input = make_kv_rearrange(past_k, in_beam_idx, qkv_order[0]);
+        concat_v_input = make_kv_rearrange(past_v, in_beam_idx, qkv_order[0]);
}

auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{concat_k_input, in_k_token}, concat_axis);
@@ -342,6 +342,7 @@ std::vector<Params> get_test_params() {
p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+    p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});

// Beam search
p.push_back({with_rearrange, !with_mask, !with_scale, !causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
@@ -351,6 +352,7 @@ std::vector<Params> get_test_params() {
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});

/* -- causal mask -- */

@@ -367,6 +369,8 @@ std::vector<Params> get_test_params() {
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});

return p;
}

