Skip to content

Commit

Permalink
Tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 18, 2024
1 parent a2e4bf2 commit 8a355f3
Show file tree
Hide file tree
Showing 14 changed files with 526 additions and 260 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,14 @@ void regmodule_offline_transformations(py::module m) {

m_offline_transformations.def(
"paged_attention_transformation",
[](std::shared_ptr<ov::Model> model, bool use_block_indices_inputs, bool use_score_outputs, bool allow_cache_rotation) {
[](std::shared_ptr<ov::Model> model,
bool use_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs, allow_cache_rotation);
manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs,
use_score_outputs,
allow_cache_rotation);
manager.run_passes(model);
},
py::arg("model"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
return node_tuple(kv_past_par, kv_current2, kv_current_reshaped, kv_concat);
}

template<class T>
template <class T>
void insert_rotation_inputs_as(OutputVector& pa_arguments, size_t layer_index) {
auto rotation_coefficients = setName(std::make_shared<T>(ov::element::f32, ov::PartialShape{-1}),
"rotation_coefficients." + std::to_string(layer_index - 1));
Expand Down Expand Up @@ -194,8 +194,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
&score_results,
&layer_index,
&rotation_coefficients_inputs_for_each_layer,
&rotated_block_indices_inputs_for_each_layer
](ov::pass::pattern::Matcher& m) {
&rotated_block_indices_inputs_for_each_layer](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto real_q = pattern_map.at(q);

Expand Down Expand Up @@ -400,8 +399,6 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par

OPENVINO_ASSERT(pa_arguments.size() == 13);



if (allow_cache_rotation) {
auto rotation_coefficients = setName(std::make_shared<v0::Parameter>(element::f32, PartialShape{-1}),
"rotation_coefficients." + std::to_string(layer_index - 1));
Expand Down
3 changes: 2 additions & 1 deletion src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass {
public:
OPENVINO_RTTI("SDPAToPagedAttention");

SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false, bool use_score_outputs = false,
SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false,
bool use_score_outputs = false,
bool allow_cache_rotation = false);
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

Expand Down
23 changes: 10 additions & 13 deletions src/core/src/op/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,33 +149,30 @@ void PagedAttentionExtension::validate_and_infer_types() {

if (get_input_size() == 15) {
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(13).rank().is_dynamic() ||
get_input_partial_shape(13).rank().get_length() == 1,
"Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(13).rank().get_length(),
".");
this,
get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1,
"Input `rotation_coefficients` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(13).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::f32,
"Element type of `rotation_coefficients` input should be f32, but it is ",
get_input_element_type(13),
".");

NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(14).rank().is_dynamic() ||
get_input_partial_shape(14).rank().get_length() == 1,
"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(14).rank().get_length(),
".");
this,
get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 1,
"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(14).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32,
"Element type of `rotated_block_indices` input should be i32, but it is ",
get_input_element_type(14),
".");
}


// the head_size of value may differ from the head_size of key
auto out_ps = get_input_partial_shape(0);
const auto& key_ps = get_input_partial_shape(1);
Expand Down
3 changes: 2 additions & 1 deletion src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

using namespace ov::op;

ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, bool use_score_outputs,
ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation)
: m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs),
m_use_score_outputs(use_score_outputs),
Expand Down
35 changes: 30 additions & 5 deletions src/core/tests/type_prop/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/paged_attention.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/openvino.hpp"
#include "openvino/opsets/opset13.hpp"

Expand All @@ -28,8 +29,19 @@ TEST(type_prop, paged_attention_static_13_inputs) {
const auto alibi_slopes = std::make_shared<opset13::Parameter>(element::f32, Shape{9});
const auto max_context_len = std::make_shared<opset13::Parameter>(element::i32, Shape{});


ov::OutputVector args = {query, key, value, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, scale, sliding_window, alibi_slopes, max_context_len};
ov::OutputVector args = {query,
key,
value,
key_cache,
value_cache,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len};
const auto op = std::make_shared<op::PagedAttentionExtension>(args);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (Shape{3, 4}));
Expand All @@ -53,10 +65,23 @@ TEST(type_prop, paged_attention_static_15_inputs) {
const auto rotation_coefficients = std::make_shared<opset13::Parameter>(element::f32, Shape{12});
const auto rotated_block_indices = std::make_shared<opset13::Parameter>(element::i32, Shape{3});

ov::OutputVector args = {query, key, value, key_cache, value_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, scale, sliding_window, alibi_slopes, max_context_len, rotation_coefficients, rotated_block_indices};
ov::OutputVector args = {query,
key,
value,
key_cache,
value_cache,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
rotation_coefficients,
rotated_block_indices};

const auto op = std::make_shared<op::PagedAttentionExtension>(args);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4}));
}

Loading

0 comments on commit 8a355f3

Please sign in to comment.