Skip to content

Commit

Permalink
[WIP][luci/service] Support dynamic shape inference for reshape
Browse files Browse the repository at this point in the history
This commit supports dynamic shape inference for reshape operation

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 9, 2024
1 parent 436119f commit 7f05707
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 67 deletions.
143 changes: 76 additions & 67 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "CircleCloneNode.h"

#include <luci/Log.h>
#include <oops/InternalExn.h>

namespace
{
Expand Down Expand Up @@ -66,100 +67,108 @@ namespace sinf
{

/**
* @note CircleReshape has new shape info in two places: 2nd input and attribute.
* This shape inference uses shape from input 'shape' node when it's constant.
* If not, shape will be from node itself. shape from attribute is not used.
*
* TODO Change this policy when not appropriate
* @note CircleReshape always has two inputs: tensor and shape.
* The shape input can be CircleOutputDummy, CircleConst, or CircleNode.
* If the shape input is CircleOutputDummy, the shape is inferred from the attribute.
* If the shape input is CircleConst, the shape is inferred from the constant.
* If the shape input is CircleNode, the shape is not inferred.
*/
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
{
LOGGER(l);

const loco::DataType S32 = loco::DataType::S32;

loco::TensorShape shape_by_input;
// CircleReshape node must have reshape/shape
if (node->shape() == nullptr)
{
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
INTERNAL_EXN("2nd input shape() should not be nullptr");
}

// Only support node's shape() is CircleConst with S32
// TODO support other node with other types
auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape_node != nullptr)
bool should_infer = true;
loco::TensorShape output_shape;
{
// Check if shape() is CircleOutputDummy
auto dummy_input = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
if (dummy_input != nullptr)
{
LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");

shape_by_input.rank(const_shape_node->size<S32>());
// Try to get shape from attribute
if (node->newShape())
{
output_shape.rank(node->newShape()->rank());

for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
output_shape.dim(axis) = node->newShape()->dim(axis);
}
}
else
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
// Or, use shape from the node itself
output_shape = circle_shape(node);
}
}
else
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
// If shape() is not CircleOutputDummy, it should be CircleConst or CircleNode
// Check if shape() is CircleConst
auto const_input = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_input != nullptr)
{
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
}
shape_by_input = shape;
}
}

loco::TensorShape shape_by_attr;
{
shape_by_attr.rank(node->newShape()->rank());
output_shape.rank(const_input->size<S32>());

for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
{
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
output_shape.dim(axis) = const_input->at<S32>(axis);
}
}
else
{
// If shape() is not CircleConst, it should be CircleNode
auto node_input = dynamic_cast<luci::CircleNode *>(node->shape());
if (node_input != nullptr)
{
output_shape.rank(node_input->dim(0).value());

for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
{
output_shape.dim(axis).unset();
}

should_infer = false;
}
}
}
}

if (!(shape_by_input == shape_by_attr))
if (should_infer)
{
INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
}

loco::TensorShape output_shape = shape_by_input;

// One of the dimensions can have special value -1, meaning its actual value should be inferred.
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
const auto input_shape = circle_shape(input);
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
const auto input_shape = circle_shape(input);
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
output_element_count *= dim_value;
}
}
else
if (unknown_dim_index != UINT32_MAX)
{
output_element_count *= dim_value;
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}
}
if (unknown_dim_index != UINT32_MAX)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}


return output_shape;
}

Expand Down
19 changes: 19 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG)

ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
}

TEST(ShapeRuleTest, reshape_no_2nd_input_NEG)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(nullptr);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(node_reshape, output_shape));
}

0 comments on commit 7f05707

Please sign in to comment.