Skip to content

Commit

Permalink
preliminary fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Jan 24, 2025
1 parent cc05aad commit 437dc90
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
#include "transformations/utils/utils.hpp"

#include "openvino/pass/visualize_tree.hpp"

using namespace ov::op;

ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs,
Expand Down Expand Up @@ -90,12 +92,21 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode

OPENVINO_ASSERT(input_ids_node, "The model doesn't contain input_ids or input_embeds input. Aborting.");

input_ids_node->set_partial_shape(PartialShape{-1});
auto input_ids_target_inputs = input_ids_node->get_output_target_inputs(0);
auto unsqueezed_input_ids =
std::make_shared<v0::Unsqueeze>(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1}));
for (const auto& target : input_ids_target_inputs) {
target.replace_source_output(unsqueezed_input_ids);
std::shared_ptr<ov::Node> processed_input_ids;
if (input_ids_node->get_friendly_name() == "input_ids") {
auto input_ids_target_inputs = input_ids_node->get_output_target_inputs(0);
input_ids_node->set_partial_shape(PartialShape{-1});
processed_input_ids =
std::make_shared<v0::Unsqueeze>(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1}));
for (const auto& target : input_ids_target_inputs) {
target.replace_source_output(processed_input_ids);
}
} else if (input_ids_node->get_friendly_name() == "inputs_embeds") {
// In VLMs the input_ids handling and the embeddings calculation
// are already done before the "inputs_embeds" input, so the
// reshape + Unsqueeze applied in the "input_ids" branch above is
// not needed here and the node is used as is.
processed_input_ids = input_ids_node;
}

ParameterVector kv_parameters;
Expand Down Expand Up @@ -141,7 +152,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
rotated_block_indices_inputs_for_each_layer,
rotation_deltas_inputs_for_each_layer,
model_rotation_trig_lut);
manager.register_pass<PrevSequenceLengthPattern>(unsqueezed_input_ids, max_context_len, position_ids);
manager.register_pass<PrevSequenceLengthPattern>(processed_input_ids, max_context_len, position_ids);
manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
manager.register_pass<TotalSequenceLengthPatternQwen>(max_context_len);
manager.register_pass<PositionIDsReplacer>(unsqueezed_position_ids);
Expand Down

0 comments on commit 437dc90

Please sign in to comment.