Skip to content

Commit

Permalink
add concat shape inference test with op reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
barnasm1 committed May 6, 2024
1 parent 59c54cb commit 0085ea1
Showing 1 changed file with 52 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,55 @@ TEST_P(ConcatStaticShapeInferenceTest, concat_static) {

ASSERT_EQ(output_shapes.front(), exp_shape);
}

TEST(ConcatStaticShapeInferenceTest, concat_static_op_reuse_1_arg) {
auto param = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto op = std::make_shared<op::v0::Concat>(NodeVector{1, param}, -1);

auto output_shapes = shape_inference(op.get(), ShapeVector{{4, 2, 1}});
ASSERT_EQ(output_shapes.front(), StaticShape({4, 2, 1}));

output_shapes = shape_inference(op.get(), ShapeVector{{1, 2, 0, 4, 5}});
ASSERT_EQ(output_shapes.front(), StaticShape({1, 2, 0, 4, 5}));
}

TEST(ConcatStaticShapeInferenceTest, concat_static_op_reuse_2_arg) {
auto param = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto op = std::make_shared<op::v0::Concat>(NodeVector{2, param}, -3);

auto output_shapes = shape_inference(op.get(), ShapeVector{{4, 2, 1}, {4, 2, 1}});
ASSERT_EQ(output_shapes.front(), StaticShape({8, 2, 1}));

output_shapes = shape_inference(op.get(), ShapeVector{{1, 2, 0, 4, 5}, {1, 2, 9, 4, 5}});
ASSERT_EQ(output_shapes.front(), StaticShape({1, 2, 9, 4, 5}));
}

TEST(ConcatStaticShapeInferenceTest, concat_static_op_reuse_2_arg_with_exception) {
auto param = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto op = std::make_shared<op::v0::Concat>(NodeVector{2, param}, -3);

auto output_shapes = shape_inference(op.get(), ShapeVector{{4, 2, 1}, {4, 2, 1}});
ASSERT_EQ(output_shapes.front(), StaticShape({8, 2, 1}));

EXPECT_THROW(shape_inference(op.get(), ShapeVector{{4}, {0}}), ov::AssertFailure);

output_shapes = shape_inference(op.get(), ShapeVector{{1, 2, 0, 4, 5}, {1, 2, 9, 4, 5}});
ASSERT_EQ(output_shapes.front(), StaticShape({1, 2, 9, 4, 5}));
}

TEST(ConcatStaticShapeInferenceTest, concat_static_op_reuse_3_arg) {
auto param = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto op = std::make_shared<op::v0::Concat>(NodeVector{3, param}, -1);

auto output_shapes = shape_inference(op.get(), ShapeVector{{4}, {0}, {2}});
ASSERT_EQ(output_shapes.front(), StaticShape({6}));

output_shapes = shape_inference(op.get(), ShapeVector{{2, 1}, {2, 1}, {2, 1}});
ASSERT_EQ(output_shapes.front(), StaticShape({2, 3}));

output_shapes = shape_inference(op.get(), ShapeVector{{4, 2, 5}, {4, 2, 1}, {4, 2, 2}});
ASSERT_EQ(output_shapes.front(), StaticShape({4, 2, 8}));

output_shapes = shape_inference(op.get(), ShapeVector{{1, 2, 3, 4, 3}, {1, 2, 3, 4, 1}, {1, 2, 3, 4, 1}});
ASSERT_EQ(output_shapes.front(), StaticShape({1, 2, 3, 4, 5}));
}

0 comments on commit 0085ea1

Please sign in to comment.