From 8fb5db663b0bf09bd57403e2a72bb794d6ebc580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=96=91=EC=A2=85=EC=9B=90?= Date: Thu, 5 Sep 2024 13:43:51 +0900 Subject: [PATCH] [WIP][luci/service] Add test cases for reshape This commit adds test cases for reshape operation. ONE-DCO-1.0-Signed-off-by: Jongwon Yang --- .../service/src/Nodes/CircleReshape.test.cpp | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 908ac3042d7..c294e907a7c 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -38,3 +38,128 @@ TEST(CloneNodeTest, clone_Reshape) ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0)); ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); } + + +TEST(ShapeRuleTest, reshape_by_const_static_input) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = 6; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +TEST(ShapeRuleTest, reshape_by_const_dynamic_input) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = -1; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +// TEST(ShapeRuleTest, reshape_by_node_static_input) +// { +// auto g = loco::make_graph(); +// auto node_reshape = g->nodes()->create(); +// auto tensor_input = g->nodes()->create(); +// auto shape_by_input = g->nodes()->create(); + +// tensor_input->dtype(loco::DataType::S32); +// tensor_input->shape({2, 3, 4}); +// tensor_input->shape_status(luci::ShapeStatus::VALID); + +// shape_by_input->dtype(loco::DataType::S32); +// shape_by_input->shape({6, 4}); +// shape_by_input->shape_status(luci::ShapeStatus::VALID); + +// node_reshape->tensor(tensor_input); +// node_reshape->shape(shape_by_input); + +// loco::TensorShape output_shape; +// luci::sinf::Rule shape_inf_rule; + +// ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + +// ASSERT_EQ(2, output_shape.rank()); +// ASSERT_TRUE(output_shape.dim(0).known()); +// ASSERT_TRUE(output_shape.dim(1).known()); +// ASSERT_EQ(6, output_shape.dim(0).value()); +// ASSERT_EQ(4, output_shape.dim(1).value()); +// } + +// TEST(ShapeRuleTest, reshape_by_node_dynamic_input) +// { +// auto g = loco::make_graph(); +// auto node_reshape = g->nodes()->create(); +// auto tensor_input = g->nodes()->create(); +// auto shape_by_input = g->nodes()->create(); + +// tensor_input->dtype(loco::DataType::S32); +// tensor_input->shape({2, 3, 4}); +// tensor_input->shape_status(luci::ShapeStatus::VALID); + +// shape_by_input->dtype(loco::DataType::S32); +// shape_by_input->rank(2); +// shape_by_input->dim(0).unset(); +// shape_by_input->dim(1).unset(); +// shape_by_input->shape_status(luci::ShapeStatus::VALID); + +// node_reshape->tensor(tensor_input); +// node_reshape->shape(shape_by_input); + +// loco::TensorShape output_shape; +// luci::sinf::Rule shape_inf_rule; + +// ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + +// ASSERT_EQ(2, output_shape.rank()); +// ASSERT_FALSE(output_shape.dim(0).known()); +// ASSERT_FALSE(output_shape.dim(1).known()); +// } \ No newline at end of file