diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index a28ad648320..28eb6303735 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -71,6 +71,8 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) const loco::DataType S32 = loco::DataType::S32; + bool is_static_shape = true; + loco::TensorShape shape_by_input; { LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); @@ -95,21 +97,12 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) } else { - // We use shape from the node itself - loco::TensorShape shape; - shape.rank(node->rank()); - for (uint32_t r = 0; r < node->rank(); ++r) - { - // TODO remove this copy from `use_own(node);` - // Shape inference rules in this file did not consider unknown dimension. - // If some node has unknown dimension, 0 is inserted and wrong shape - // inference was done as a result. - // To fix this, new shape inference algorithm is being implemented. - // Until new inference algorithm is fully implemented, unknown dimension - // would be represented as 1 along with TFLite expression. - shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; - } - shape_by_input = shape; + auto shape_node = loco::must_cast(node->shape()); + assert(shape_node->rank() == 1); + // shape_node tensor values will provide new shape, like [2, 3, 4] + auto num_elements = shape_node->dim(0).value(); // above example will give 3 + shape_by_input.rank(num_elements); + is_static_shape = false; } } @@ -138,7 +131,6 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) uint32_t input_element_count = 1; uint32_t output_element_count = 1; uint32_t unknown_dim_index = UINT32_MAX; - bool is_static_shape = true; for (uint32_t i = 0; i < input_shape.rank(); ++i) { if (input_shape.dim(i).known()) @@ -146,23 +138,27 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) else is_static_shape = false; } - for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + + if (is_static_shape) { - const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (not output_shape.dim(dim_index).known()) + for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { - LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); - unknown_dim_index = dim_index; + const uint32_t dim_value = output_shape.dim(dim_index).value(); + if (not output_shape.dim(dim_index).known()) + { + LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); + unknown_dim_index = dim_index; + } + else + { + output_element_count *= dim_value; + } } - else + if (unknown_dim_index != UINT32_MAX) { - output_element_count *= dim_value; + output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; } } - if (unknown_dim_index != UINT32_MAX && is_static_shape) - { - output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; - } return output_shape; } diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 65382ce1f8e..4bb13edc2f9 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -134,3 +134,31 @@ TEST(ShapeRuleTest, reshape_should_infer) ASSERT_TRUE(output_shape.dim(1).known()); ASSERT_EQ(4, output_shape.dim(1).value()); } + +TEST(ShapeRuleTest, reshape_by_input_node) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->shape({2}); + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_FALSE(output_shape.dim(0).known()); + ASSERT_FALSE(output_shape.dim(1).known()); +}