Skip to content

Commit

Permalink
[luci/service] Handle CircleNode as reshape's shape (Samsung#14128)
Browse files Browse the repository at this point in the history
This commit supports handling of CircleNode as reshape's shape.
This is a part of the new shape inference policy of reshape.

ONE-DCO-Signed-off-by: Jongwon Yang <[email protected]>
Co-authored-by: SaeHie Park <[email protected]>
  • Loading branch information
jongwonyang and seanshpark authored Sep 30, 2024
1 parent b8fe99d commit bf49fed
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 27 deletions.
50 changes: 23 additions & 27 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)

const loco::DataType S32 = loco::DataType::S32;

bool is_static_shape = true;

loco::TensorShape shape_by_input;
{
LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
Expand All @@ -95,21 +97,12 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
}
else
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
{
// TODO remove this copy from `use_own(node);`
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
}
shape_by_input = shape;
auto shape_node = loco::must_cast<luci::CircleNode *>(node->shape());
assert(shape_node->rank() == 1);
// shape_node tensor values will provide new shape, like [2, 3, 4]
auto num_elements = shape_node->dim(0).value(); // above example will give 3
shape_by_input.rank(num_elements);
is_static_shape = false;
}
}

Expand Down Expand Up @@ -138,31 +131,34 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
bool is_static_shape = true;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
{
if (input_shape.dim(i).known())
input_element_count *= input_shape.dim(i).value();
else
is_static_shape = false;
}
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)

if (is_static_shape)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (not output_shape.dim(dim_index).known())
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (not output_shape.dim(dim_index).known())
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
}
else
{
output_element_count *= dim_value;
}
}
else
if (unknown_dim_index != UINT32_MAX)
{
output_element_count *= dim_value;
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}
}
if (unknown_dim_index != UINT32_MAX && is_static_shape)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}

return output_shape;
}
Expand Down
28 changes: 28 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,31 @@ TEST(ShapeRuleTest, reshape_should_infer)
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_input_node)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleInput>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->shape({2});
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_FALSE(output_shape.dim(0).known());
ASSERT_FALSE(output_shape.dim(1).known());
}

0 comments on commit bf49fed

Please sign in to comment.