diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 0de10960b51..bac8b874e91 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -65,6 +65,13 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleReshape *node) namespace sinf { +/** + * @note CircleReshape always has two inputs: `tensor` and `shape`. + * The `shape` can be CircleConst, CircleOutputDummy, or CircleNode. + * - If the `shape` is CircleConst, the shape is inferred from the constant. + * - Else, the shape is inferred from the node iteself. + * - TODO support CircleOutputDummy and CircleNode + */ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) { LOGGER(l); @@ -77,8 +84,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) // Only support node's shape() is CircleConst with S32 // TODO support other node with other types - auto const_shape_node = dynamic_cast(node->shape()); - if (const_shape_node != nullptr) + if (auto const_shape_node = dynamic_cast(node->shape())) { LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); @@ -87,6 +93,10 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) { shape_by_input.dim(axis) = const_shape_node->at(axis); + if (const_shape_node->at(axis) < 0) + { + shape_by_input.dim(axis).unset(); + } } } else @@ -139,7 +149,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (static_cast(dim_value) == -1) + if (output_shape.dim(dim_index).known() == false) { LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); unknown_dim_index = dim_index; diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index a6ae6735500..50ce3aaa3f6 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -39,7 +39,7 @@ TEST(CloneNodeTest, clone_Reshape) ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); } -TEST(ShapeRuleTest, reshape_by_input_const_static) +TEST(ShapeRuleTest, reshape_by_const_static) { auto g = loco::make_graph(); auto node_reshape = g->nodes()->create(); @@ -71,7 +71,7 @@ TEST(ShapeRuleTest, reshape_by_input_const_static) ASSERT_EQ(4, output_shape.dim(1).value()); } -TEST(ShapeRuleTest, reshape_by_input_const_dynamic) +TEST(ShapeRuleTest, reshape_by_const_dynamic) { auto g = loco::make_graph(); auto node_reshape = g->nodes()->create();