Skip to content

Commit

Permalink
[luci/service] Migrate Reshape shape inference rule to sinf::Algorithm
Browse files Browse the repository at this point in the history
This commit migrates Reshape shape inference rule to sinf::Algorithm.

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 11, 2024
1 parent c3a9c0b commit 82d2717
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleRelu0To1 *node) final;
// loco::TensorShape visit(const luci::CircleRelu6 *node) final;
// loco::TensorShape visit(const luci::CircleReluN1To1 *node) final;
// loco::TensorShape visit(const luci::CircleReshape *node) final;
loco::TensorShape visit(const luci::CircleReshape *node) final;
// loco::TensorShape visit(const luci::CircleResizeBilinear *node) final;
// loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final;
// loco::TensorShape visit(const luci::CircleReverseSequence *node) final;
Expand Down
87 changes: 0 additions & 87 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -996,84 +996,6 @@ loco::NodeShape infer_range(const luci::CircleRange *node)
return loco::NodeShape{output_shape};
}

loco::NodeShape infer_reshape(const luci::CircleReshape *node)
{
LOGGER(l);

const loco::DataType S32 = loco::DataType::S32;

loco::TensorShape shape_by_input;
{
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");

// Only support node's shape() is CircleConst with S32
// TODO support other node with other types
auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape_node != nullptr)
{
LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");

shape_by_input.rank(const_shape_node->size<S32>());

for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
}
}
else
{
// We use shape from the node itself
shape_by_input = own_shape(node);
}
}

loco::TensorShape shape_by_attr;
{
shape_by_attr.rank(node->newShape()->rank());

for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
{
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
}
}

if (!(shape_by_input == shape_by_attr))
{
INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
}

loco::TensorShape output_shape = shape_by_input;

// One of the dimensions can have special value -1, meaning its actual value should be inferred.
const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
output_element_count *= dim_value;
}
}
if (unknown_dim_index != UINT32_MAX)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}

return loco::NodeShape{output_shape};
}

template <class CIRCLENODE> loco::NodeShape infer_resize_type(const CIRCLENODE *node)
{
auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
Expand Down Expand Up @@ -2228,15 +2150,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS
return loco::NodeShape{input_shape};
}

/**
* @note CircleReshape has new shape info in two places: 2nd input and attribute.
* This shape inference uses shape from input 'shape' node when it's constant.
* If not, shape will be from node itself. shape from attribute is not used.
*
* TODO Change this policy when not appropriate
*/
loco::NodeShape visit(const luci::CircleReshape *node) final { return infer_reshape(node); }

loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
{
return infer_resize_type(node);
Expand Down
125 changes: 125 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,36 @@
* limitations under the License.
*/

#include "luci/Service/CircleShapeInference.h"
#include "Check.h"

#include "CircleShapeInferenceHelper.h"
#include "CircleCloneNode.h"

#include <luci/Log.h>

namespace
{

std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
{
os << "[";
for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
{
if (r)
os << ",";

if (tensor_shape.dim(r).known())
os << tensor_shape.dim(r).value();
else
os << "?";
}
os << "]";
return os;
}

} // namespace

namespace luci
{

Expand All @@ -34,4 +62,101 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
return cloned;
}

namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
{
LOGGER(l);

const loco::DataType S32 = loco::DataType::S32;

loco::TensorShape shape_by_input;
{
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");

// Only support node's shape() is CircleConst with S32
// TODO support other node with other types
auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape_node != nullptr)
{
LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");

shape_by_input.rank(const_shape_node->size<S32>());

for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
}
}
else
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
{
// TODO remove this copy from `use_own(node);`
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
}
shape_by_input = shape;
}
}

loco::TensorShape shape_by_attr;
{
shape_by_attr.rank(node->newShape()->rank());

for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
{
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
}
}

if (!(shape_by_input == shape_by_attr))
{
INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
}

loco::TensorShape output_shape = shape_by_input;

// One of the dimensions can have special value -1, meaning its actual value should be inferred.
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
const auto input_shape = circle_shape(input);
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
output_element_count *= dim_value;
}
}
if (unknown_dim_index != UINT32_MAX)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}

return output_shape;
}

} // namespace sinf

} // namespace luci
65 changes: 65 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "luci/Service/CircleNodeClone.h"
#include "luci/Service/CircleShapeInference.h"

#include <gtest/gtest.h>

Expand All @@ -37,3 +38,67 @@ TEST(CloneNodeTest, clone_Reshape)
ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0));
ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
}

TEST(ShapeRuleTest, reshape_by_input_const_static)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_TRUE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = -1;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_TRUE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

0 comments on commit 82d2717

Please sign in to comment.