From 8a7d43392971d09ce9eac3a96af34a1f67e1fdc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=96=91=EC=A2=85=EC=9B=90?= Date: Wed, 4 Sep 2024 13:31:15 +0900 Subject: [PATCH 1/3] [luci/service] Migrate Reshape shape inference rule to sinf::Algorithm This commit migrates Reshape shape inference rule to sinf::Algorithm. ONE-DCO-1.0-Signed-off-by: Jongwon Yang --- .../luci/Service/CircleShapeInference.h | 2 +- .../service/src/CircleShapeInferenceRule.cpp | 87 ------------ .../luci/service/src/Nodes/CircleReshape.cpp | 131 ++++++++++++++++++ .../service/src/Nodes/CircleReshape.test.cpp | 1 + 4 files changed, 133 insertions(+), 88 deletions(-) diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index 176390cf40b..8906983576f 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -122,7 +122,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleRelu0To1 *node) final; // loco::TensorShape visit(const luci::CircleRelu6 *node) final; // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final; - // loco::TensorShape visit(const luci::CircleReshape *node) final; + loco::TensorShape visit(const luci::CircleReshape *node) final; // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final; // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final; // loco::TensorShape visit(const luci::CircleReverseSequence *node) final; diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 2514533696f..c6a57b5b723 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -996,84 +996,6 @@ loco::NodeShape infer_range(const luci::CircleRange *node) return loco::NodeShape{output_shape}; } -loco::NodeShape infer_reshape(const luci::CircleReshape *node) -{ - LOGGER(l); - - const loco::DataType S32 = loco::DataType::S32; - - loco::TensorShape shape_by_input; - { - LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); - - // Only support node's shape() is CircleConst with S32 - // TODO support other node with other types - auto const_shape_node = dynamic_cast(node->shape()); - if (const_shape_node != nullptr) - { - LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); - - shape_by_input.rank(const_shape_node->size()); - - for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) - { - shape_by_input.dim(axis) = const_shape_node->at(axis); - } - } - else - { - // We use shape from the node itself - shape_by_input = own_shape(node); - } - } - - loco::TensorShape shape_by_attr; - { - shape_by_attr.rank(node->newShape()->rank()); - - for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) - { - shape_by_attr.dim(axis) = node->newShape()->dim(axis); - } - } - - if (!(shape_by_input == shape_by_attr)) - { - INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl; - INFO(l) << " shape_by_input : " << shape_by_input << std::endl; - INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl; - } - - loco::TensorShape output_shape = shape_by_input; - - // One of the dimensions can have special value -1, meaning its actual value should be inferred. - const auto input_shape = luci::shape_get(node->tensor()).as(); - uint32_t input_element_count = 1; - uint32_t output_element_count = 1; - uint32_t unknown_dim_index = UINT32_MAX; - for (uint32_t i = 0; i < input_shape.rank(); ++i) - input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); - for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) - { - const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (static_cast(dim_value) == -1) - { - LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); - unknown_dim_index = dim_index; - } - else - { - output_element_count *= dim_value; - } - } - if (unknown_dim_index != UINT32_MAX) - { - output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; - } - - return loco::NodeShape{output_shape}; -} - template loco::NodeShape infer_resize_type(const CIRCLENODE *node) { auto input_shape = luci::shape_get(node->input()).template as(); @@ -2228,15 +2150,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor + +namespace +{ + +std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape) +{ + os << "["; + for (uint32_t r = 0; r < tensor_shape.rank(); ++r) + { + if (r) + os << ","; + + if (tensor_shape.dim(r).known()) + os << tensor_shape.dim(r).value(); + else + os << "?"; + } + os << "]"; + return os; +} + +} // namespace + namespace luci { @@ -34,4 +62,107 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleReshape *node) return cloned; } +namespace sinf +{ + +/** + * @note CircleReshape has new shape info in two places: 2nd input and attribute. + * This shape inference uses shape from input 'shape' node when it's constant. + * If not, shape will be from node itself. shape from attribute is not used. + * + * TODO Change this policy when not appropriate + */ +loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) +{ + LOGGER(l); + + const loco::DataType S32 = loco::DataType::S32; + + loco::TensorShape shape_by_input; + { + LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + + // Only support node's shape() is CircleConst with S32 + // TODO support other node with other types + auto const_shape_node = dynamic_cast(node->shape()); + if (const_shape_node != nullptr) + { + LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); + + shape_by_input.rank(const_shape_node->size()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + shape_by_input.dim(axis) = const_shape_node->at(axis); + } + } + else + { + // We use shape from the node itself + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t r = 0; r < node->rank(); ++r) + { + // Shape inference rules in this file did not consider unknown dimension. + // If some node has unknown dimension, 0 is inserted and wrong shape + // inference was done as a result. + // To fix this, new shape inference algorithm is being implemented. + // Until new inference algorithm is fully implemented, unknown dimension + // would be represented as 1 along with TFLite expression. + shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + } + shape_by_input = shape; + } + } + + loco::TensorShape shape_by_attr; + { + shape_by_attr.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) + { + shape_by_attr.dim(axis) = node->newShape()->dim(axis); + } + } + + if (!(shape_by_input == shape_by_attr)) + { + INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl; + INFO(l) << " shape_by_input : " << shape_by_input << std::endl; + INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl; + } + + loco::TensorShape output_shape = shape_by_input; + + // One of the dimensions can have special value -1, meaning its actual value should be inferred. + const auto input = loco::must_cast(node->tensor()); + const auto input_shape = circle_shape(input); + uint32_t input_element_count = 1; + uint32_t output_element_count = 1; + uint32_t unknown_dim_index = UINT32_MAX; + for (uint32_t i = 0; i < input_shape.rank(); ++i) + input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); + for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + { + const uint32_t dim_value = output_shape.dim(dim_index).value(); + if (static_cast(dim_value) == -1) + { + LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); + unknown_dim_index = dim_index; + } + else + { + output_element_count *= dim_value; + } + } + if (unknown_dim_index != UINT32_MAX) + { + output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; + } + + return output_shape; +} + +} // namespace sinf + } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index ca92b717d0c..908ac3042d7 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -15,6 +15,7 @@ */ #include "luci/Service/CircleNodeClone.h" +#include "luci/Service/CircleShapeInference.h" #include From 436119fdec0d93e65657734256241045393b88df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=96=91=EC=A2=85=EC=9B=90?= Date: Thu, 5 Sep 2024 13:43:51 +0900 Subject: [PATCH 2/3] [luci/service] Add test cases for reshape This commit adds test cases for reshape operation. ONE-DCO-1.0-Signed-off-by: Jongwon Yang --- .../service/src/Nodes/CircleReshape.test.cpp | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 908ac3042d7..39ed4b049d8 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -38,3 +38,119 @@ TEST(CloneNodeTest, clone_Reshape) ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0)); ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); } + +TEST(ShapeRuleTest, reshape_by_input_const_static) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = 6; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +TEST(ShapeRuleTest, reshape_by_input_const_dynamic) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = -1; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +TEST(ShapeRuleTest, reshape_input_tensor_undefined_NEG) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::UNDEFINED); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = 6; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape)); +} + +TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = 6; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::UNDEFINED); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape)); +} From 0c2c5ac19b7d80b1c33780b3467c9192f7c38108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=96=91=EC=A2=85=EC=9B=90?= Date: Mon, 9 Sep 2024 16:51:28 +0900 Subject: [PATCH 3/3] [luci/service] Support dynamic shape inference for reshape This commit supports dynamic shape inference for reshape operation ONE-DCO-1.0-Signed-off-by: Jongwon Yang --- compiler/common-artifacts/exclude.lst | 1 + .../luci/service/src/Nodes/CircleReshape.cpp | 175 +++++++++--------- .../service/src/Nodes/CircleReshape.test.cpp | 105 +++++++++-- 3 files changed, 180 insertions(+), 101 deletions(-) diff --git a/compiler/common-artifacts/exclude.lst b/compiler/common-artifacts/exclude.lst index 4358bc02cdd..263b006e372 100644 --- a/compiler/common-artifacts/exclude.lst +++ b/compiler/common-artifacts/exclude.lst @@ -7,6 +7,7 @@ ## TensorFlowLiteRecipes optimize(Add_STR_000) # STRING is not supported optimize(Add_STR_001) # STRING is not supported +optimize(Net_Gather_SparseToDense_AddV2_000) # Constant folding is not generally supported ## CircleRecipes diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 080115bf2fe..e6314a061dd 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -20,29 +20,7 @@ #include "CircleShapeInferenceHelper.h" #include "CircleCloneNode.h" -#include - -namespace -{ - -std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape) -{ - os << "["; - for (uint32_t r = 0; r < tensor_shape.rank(); ++r) - { - if (r) - os << ","; - - if (tensor_shape.dim(r).known()) - os << tensor_shape.dim(r).value(); - else - os << "?"; - } - os << "]"; - return os; -} - -} // namespace +#include namespace luci { @@ -66,98 +44,125 @@ namespace sinf { /** - * @note CircleReshape has new shape info in two places: 2nd input and attribute. - * This shape inference uses shape from input 'shape' node when it's constant. - * If not, shape will be from node itself. shape from attribute is not used. - * - * TODO Change this policy when not appropriate + * @note CircleReshape always has two inputs: tensor and shape. + * The shape input can be CircleConst, CircleOutputDummy, or CircleNode. + * - If the shape input is CircleConst, the shape is inferred from the constant. + * - If the shape input is CircleOutputDummy, the shape is inferred from + * the attribute if it exists. If the attribute does not exist, + * the shape is inferred from the node iteself. + * - If the shape input is CircleNode, the shape is not inferred. */ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) { - LOGGER(l); - const loco::DataType S32 = loco::DataType::S32; - loco::TensorShape shape_by_input; + // CircleReshape node must have reshape/shape + if (node->shape() == nullptr) { - LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + INTERNAL_EXN("2nd input shape() should not be nullptr"); + } - // Only support node's shape() is CircleConst with S32 - // TODO support other node with other types - auto const_shape_node = dynamic_cast(node->shape()); - if (const_shape_node != nullptr) + bool should_infer = true; + loco::TensorShape output_shape; + { + // Check if reshape/shape is CircleConst + auto const_input = dynamic_cast(node->shape()); + if (const_input != nullptr) { - LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); + output_shape.rank(const_input->size()); - shape_by_input.rank(const_shape_node->size()); - - for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) { - shape_by_input.dim(axis) = const_shape_node->at(axis); + output_shape.dim(axis) = const_input->at(axis); + if (const_input->at(axis) < 0) + { + output_shape.dim(axis).unset(); + } } } else { - // We use shape from the node itself - loco::TensorShape shape; - shape.rank(node->rank()); - for (uint32_t r = 0; r < node->rank(); ++r) + // Check if reshape/shape is CircleOutputDummy + auto dummy_input = dynamic_cast(node->shape()); + if (dummy_input != nullptr) { - // Shape inference rules in this file did not consider unknown dimension. - // If some node has unknown dimension, 0 is inserted and wrong shape - // inference was done as a result. - // To fix this, new shape inference algorithm is being implemented. - // Until new inference algorithm is fully implemented, unknown dimension - // would be represented as 1 along with TFLite expression. - shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + if (node->newShape()->rank() > 0) + { + output_shape.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) + { + output_shape.dim(axis) = node->newShape()->dim(axis); + if (node->newShape()->dim(axis) < 0) + { + output_shape.dim(axis).unset(); + } + } + } + else + { + output_shape = circle_shape(node); + } + } + else + { + // Check if reshape/shape is CircleNode + auto node_input = dynamic_cast(node->shape()); + if (node_input != nullptr) + { + output_shape.rank(node_input->dim(0).value()); + + for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) + { + output_shape.dim(axis).unset(); + } + + should_infer = false; + } } - shape_by_input = shape; - } - } - - loco::TensorShape shape_by_attr; - { - shape_by_attr.rank(node->newShape()->rank()); - - for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) - { - shape_by_attr.dim(axis) = node->newShape()->dim(axis); } } - if (!(shape_by_input == shape_by_attr)) - { - INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl; - INFO(l) << " shape_by_input : " << shape_by_input << std::endl; - INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl; - } - - loco::TensorShape output_shape = shape_by_input; - - // One of the dimensions can have special value -1, meaning its actual value should be inferred. const auto input = loco::must_cast(node->tensor()); const auto input_shape = circle_shape(input); uint32_t input_element_count = 1; - uint32_t output_element_count = 1; - uint32_t unknown_dim_index = UINT32_MAX; - for (uint32_t i = 0; i < input_shape.rank(); ++i) - input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); - for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + for (uint32_t axis = 0; axis < input_shape.rank(); ++axis) { - const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (static_cast(dim_value) == -1) + if (input_shape.dim(axis).known()) { - LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); - unknown_dim_index = dim_index; + input_element_count *= input_shape.dim(axis).value(); } else { - output_element_count *= dim_value; + should_infer = false; + break; } } - if (unknown_dim_index != UINT32_MAX) + + if (should_infer) { - output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; + uint32_t output_element_count = 1; + uint32_t unknown_dim_index = UINT32_MAX; + for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + { + if (output_shape.dim(dim_index).known() == false) + { + if (unknown_dim_index != UINT32_MAX) + { + INTERNAL_EXN("More than one unknown dimension"); + } + unknown_dim_index = dim_index; + } + else + { + const uint32_t dim_value = output_shape.dim(dim_index).value(); + output_element_count *= dim_value; + } + } + if (unknown_dim_index != UINT32_MAX) + { + output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; + } } return output_shape; diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 39ed4b049d8..07aaed427a7 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -39,25 +39,25 @@ TEST(CloneNodeTest, clone_Reshape) ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); } -TEST(ShapeRuleTest, reshape_by_input_const_static) +TEST(ShapeRuleTest, reshape_by_circle_const) { auto g = loco::make_graph(); auto node_reshape = g->nodes()->create(); auto tensor_input = g->nodes()->create(); - auto shape_by_input = g->nodes()->create(); + auto shape_input = g->nodes()->create(); tensor_input->dtype(loco::DataType::S32); tensor_input->shape({2, 3, 4}); tensor_input->shape_status(luci::ShapeStatus::VALID); - shape_by_input->dtype(loco::DataType::S32); - shape_by_input->size(2); - shape_by_input->at(0) = 6; - shape_by_input->at(1) = 4; - shape_by_input->shape_status(luci::ShapeStatus::VALID); + shape_input->dtype(loco::DataType::S32); + shape_input->size(2); + shape_input->at(0) = -1; + shape_input->at(1) = 4; + shape_input->shape_status(luci::ShapeStatus::VALID); node_reshape->tensor(tensor_input); - node_reshape->shape(shape_by_input); + node_reshape->shape(shape_input); loco::TensorShape output_shape; luci::sinf::Rule shape_inf_rule; @@ -71,25 +71,25 @@ TEST(ShapeRuleTest, reshape_by_input_const_static) ASSERT_EQ(4, output_shape.dim(1).value()); } -TEST(ShapeRuleTest, reshape_by_input_const_dynamic) +TEST(ShapeRuleTest, reshape_by_circle_dummy) { auto g = loco::make_graph(); auto node_reshape = g->nodes()->create(); auto tensor_input = g->nodes()->create(); - auto shape_by_input = g->nodes()->create(); + auto shape_input = g->nodes()->create(); tensor_input->dtype(loco::DataType::S32); tensor_input->shape({2, 3, 4}); tensor_input->shape_status(luci::ShapeStatus::VALID); - shape_by_input->dtype(loco::DataType::S32); - shape_by_input->size(2); - shape_by_input->at(0) = -1; - shape_by_input->at(1) = 4; - shape_by_input->shape_status(luci::ShapeStatus::VALID); + shape_input->dtype(loco::DataType::S32); + shape_input->shape_status(luci::ShapeStatus::VALID); node_reshape->tensor(tensor_input); - node_reshape->shape(shape_by_input); + node_reshape->shape(shape_input); + node_reshape->newShape()->rank(2); + node_reshape->newShape()->dim(0) = -1; + node_reshape->newShape()->dim(1) = 4; loco::TensorShape output_shape; luci::sinf::Rule shape_inf_rule; @@ -103,6 +103,34 @@ TEST(ShapeRuleTest, reshape_by_input_const_dynamic) ASSERT_EQ(4, output_shape.dim(1).value()); } +TEST(ShapeRuleTest, reshape_by_circle_node) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_input->dtype(loco::DataType::S32); + shape_input->shape({2}); + shape_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_FALSE(output_shape.dim(0).known()); + ASSERT_FALSE(output_shape.dim(1).known()); +} + TEST(ShapeRuleTest, reshape_input_tensor_undefined_NEG) { auto g = loco::make_graph(); @@ -154,3 +182,48 @@ TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG) ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape)); } + +TEST(ShapeRuleTest, reshape_no_input_shape_NEG) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(nullptr); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(node_reshape, output_shape)); +} + +TEST(ShapeRuleTest, reshape_too_many_unknown_NEG) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_input->dtype(loco::DataType::S32); + shape_input->size(2); + shape_input->at(0) = -1; + shape_input->at(1) = -1; + shape_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_ANY_THROW(shape_inf_rule.infer(node_reshape, output_shape)); +}