diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 2ccd19ca3e1fc3..7df36e68b50ffd 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -18,6 +18,8 @@ #include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp" #include "transformations/utils/utils.hpp" +#include "openvino/pass/visualize_tree.hpp" + using namespace ov::op; ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, @@ -90,12 +92,21 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrset_partial_shape(PartialShape{-1}); - auto input_ids_target_inputs = input_ids_node->get_output_target_inputs(0); - auto unsqueezed_input_ids = - std::make_shared(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1})); - for (const auto& target : input_ids_target_inputs) { - target.replace_source_output(unsqueezed_input_ids); + std::shared_ptr processed_input_ids; + if (input_ids_node->get_friendly_name() == "input_ids") { + auto input_ids_target_inputs = input_ids_node->get_output_target_inputs(0); + input_ids_node->set_partial_shape(PartialShape{-1}); + processed_input_ids = + std::make_shared(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1})); + for (const auto& target : input_ids_target_inputs) { + target.replace_source_output(processed_input_ids); + } + } else if (input_ids_node->get_friendly_name() == "inputs_embeds") { + // VLMs have the input_ids part + embeddings calculation + // served as "inputs_embeds" input, so there's no need + // for additional work on the input here as this is done + // for "input_ids" + processed_input_ids = input_ids_node; } ParameterVector kv_parameters; @@ -141,7 +152,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(unsqueezed_input_ids, max_context_len, position_ids); + manager.register_pass(processed_input_ids, max_context_len, position_ids); manager.register_pass(max_context_len); manager.register_pass(max_context_len); manager.register_pass(unsqueezed_position_ids);