From 17081b435a1d4a7b7a9e73d93658179a154f6991 Mon Sep 17 00:00:00 2001 From: Hyukjin Jeong Date: Wed, 18 Dec 2024 18:51:59 +0900 Subject: [PATCH] [luci] Fix ForwardRehsapeToUnaryOpPass bug (#14474) This updates rank of constant after reshape is forwarded. ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong --- .../pass/src/ForwardReshapeToUnaryOpPass.cpp | 13 +++ .../src/ForwardReshapeToUnaryOpPass.test.cpp | 106 ++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp index 21ac7adfbc7..cba8d35b3a7 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -186,16 +186,29 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleMul *div, if (not new_reshape) return false; + const auto prev = loco::must_cast(reshape->tensor()); + + // Reshape can change rank of tensor, so we need to update constant value accordingly. + assert(const_value->size() == 1); + auto cloned_const = clone(const_value); + cloned_const->rank(prev->rank()); + for (uint32_t i = 0; i < prev->rank(); ++i) + { + cloned_const->dim(i).set(1); + } + // reconnect network loco::replace(div).with(new_reshape); if (div->x() == const_value) { + div->x(cloned_const); div->y(reshape->tensor()); } else { assert(div->y() == const_value); div->x(reshape->tensor()); + div->y(cloned_const); } new_reshape->tensor(div); diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp index ae89e4fad38..eea48efbc7c 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp @@ -101,6 +101,56 @@ class ReshapeLogisticGraphlet luci::CircleConst *_reshape_shape = nullptr; }; +class ReshapeMulGraphlet +{ +public: + ReshapeMulGraphlet() = default; + +public: + void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out) + { + std::vector shape_out_v = shape_out; + + _reshape_shape = g->nodes()->create(); + _reshape = g->nodes()->create(); + _mul = g->nodes()->create(); + _const = g->nodes()->create(); + + _reshape_shape->dtype(loco::DataType::S32); + _reshape_shape->rank(1); + _reshape_shape->dim(0).set(shape_out_v.size()); + _reshape_shape->shape_status(luci::ShapeStatus::VALID); + + // values + const auto size = shape_out_v.size(); + _reshape_shape->size(size); + for (uint32_t i = 0; i < size; i++) + _reshape_shape->at(i) = shape_out_v[i]; + + _const->dtype(loco::DataType::FLOAT32); + _const->rank(size); + uint32_t numel = 1; + for (uint32_t i = 0; i < size; i++) + { + _const->dim(i).set(1); + } + _const->size(1); + _const->at(0) = 1.0; + _const->shape_status(luci::ShapeStatus::VALID); + + _reshape_shape->name("reshape_shape"); + _reshape->name("reshape"); + _mul->name("mul"); + _const->name("const"); + } + +protected: + luci::CircleReshape *_reshape = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_const = nullptr; + luci::CircleConst *_reshape_shape = nullptr; +}; + class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet { public: @@ -141,6 +191,27 @@ class ForwardReshapeToLogisticGraph : public TestIOGraph, public ReshapeLogistic } }; +class ForwardReshapeToMulGraph : public TestIOGraph, public ReshapeMulGraphlet +{ +public: + ForwardReshapeToMulGraph() = default; + +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + ReshapeMulGraphlet::init(g(), shape_in, shape_out); + + // connect network + _reshape->tensor(input()); + _reshape->shape(_reshape_shape); + _mul->x(_reshape); + _mul->y(_const); + + output()->from(_mul); + } +}; + class ForwardReshapeToNegGraphTest : public ::testing::Test { public: @@ -173,6 +244,22 @@ class ForwardReshapeToLogisticGraphTest : public ::testing::Test luci::ForwardReshapeToUnaryOpPass _pass; }; +class ForwardReshapeToMulGraphTest : public ::testing::Test +{ +public: + ForwardReshapeToMulGraphTest() = default; + + void run_pass(void) + { + while (_pass.run(_graph.g())) + ; + } + +protected: + ForwardReshapeToMulGraph _graph; + luci::ForwardReshapeToUnaryOpPass _pass; +}; + /** * Simple graph for test * @@ -318,6 +405,25 @@ TEST_F(ForwardReshapeToLogisticGraphTest, forward) ASSERT_NE(nullptr, log); } +TEST_F(ForwardReshapeToMulGraphTest, forward_rank_update) +{ + _graph.init({1, 2, 3}, {1, 1, 2, 3}); + + run_pass(); + + auto reshape = dynamic_cast(_graph.output()->from()); + ASSERT_NE(nullptr, reshape); + + auto mul = dynamic_cast(reshape->tensor()); + ASSERT_NE(nullptr, mul); + + // Check mul's const rank == input rank (3) + auto const_mul = dynamic_cast(mul->y()); + ASSERT_NE(nullptr, const_mul); + + ASSERT_EQ(3, const_mul->rank()); +} + TEST(FuseMulWithDivPassTest, forward_reshape_to_mean_pattern) { ForwardReshapeToMeanPatternTestGraph g;