Skip to content

Commit

Permalink
Update scaled_dot_product_decomposition_test.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrKrzem authored Nov 29, 2023
1 parent 75893c3 commit 764e734
Showing 1 changed file with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(
const bool casual);

TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStatic) {
const PartialShape query_key_value_mask_shape{1, 32, 32};
const PartialShape scale_shape{};
const PartialShape query_key_value_mask_scale__shape{32, 32, 32};

const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto scale = std::make_shared<ov::op::v0::Parameter>(element::f32, scale_shape);
const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto scale = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto casual = false;
{
const auto scaled_dot_product_attention =
Expand All @@ -71,11 +70,11 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
const PartialShape query_key_value_mask_shape{-1, -1, -1};
const PartialShape scale_shape{};

const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_shape);
const auto scale = std::make_shared<ov::op::v0::Parameter>(element::f32, scale_shape);
const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto scale = std::make_shared<ov::op::v0::Parameter>(element::f32, query_key_value_mask_scale__shape);
const auto casual = false;
{
const auto scaled_dot_product_attention =
Expand Down

0 comments on commit 764e734

Please sign in to comment.