From 11199a74c5738c63315f0377b0e1237ce2e152c1 Mon Sep 17 00:00:00 2001 From: Hyukjin Jeong Date: Fri, 30 Aug 2024 18:25:13 +0900 Subject: [PATCH] [luci] Forward transpose across single element MUL (#13859) This forwards transpose across MUL with single element const. ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong --- .../luci/pass/src/ForwardTransposeOpPass.cpp | 86 +++++++++++ .../pass/src/ForwardTransposeOpPass.test.cpp | 136 +++++++++++++++++- 2 files changed, 219 insertions(+), 3 deletions(-) diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp index b9f7ae5a8e0..a8ae3b305e3 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/ForwardTransposeOpPass.h" +#include "helpers/NodeFiller.h" + #include #include #include @@ -150,6 +152,25 @@ bool check_perm(const CircleTranspose *t) return true; } +bool has_single_element(const luci::CircleConst *node) +{ + bool has_single_elem = false; + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + has_single_elem = node->size() == 1; + break; + default: + // NYI + break; + } + + if (has_single_elem) + assert(node->rank() == 0 or node->rank() == 1); // FIX_ME_UNLESS + + return has_single_elem; +} + #define RETURN_FALSE_UNLESS(COND) \ if (not(COND)) \ return false; @@ -158,8 +179,72 @@ bool check_perm(const CircleTranspose *t) class EBOWithConstPattern final : public CircleNodeMutableVisitor { private: + // TODO Rename this to has_commutative_pattern template bool has_pattern(CIRCLE_OP_PTR node) { + luci::CircleTranspose *transpose = nullptr; + luci::CircleConst *const_value = nullptr; + + RETURN_FALSE_UNLESS(luci::fill(&transpose, &const_value).with_commutative_args_of(node)); + + if (has_single_element(const_value)) + { + RETURN_FALSE_UNLESS(check_perm(transpose)); + auto new_transpose = create_cloned_transpose(transpose); + assert(new_transpose); // FIX_ME_UNLESS + + if (node->x() == const_value) + { + node->y(transpose->a()); + } + else + { + assert(node->y() == const_value); + node->x(transpose->a()); + } + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + else if (const_value->rank() == transpose->rank()) + { + // Only support rank 4 for now + RETURN_FALSE_UNLESS(check_rank_four(const_value)); + RETURN_FALSE_UNLESS(check_perm(transpose)); + + auto new_const = gen_new_const(transpose, const_value); + assert(new_const); // FIX_ME_UNLESS + + auto new_transpose = create_cloned_transpose(transpose); + assert(new_transpose); // FIX_ME_UNLESS + + // Reconnect network + if (node->x() == const_value) + { + node->x(new_const); + node->y(transpose->a()); + } + else + { + node->x(transpose->a()); + node->y(new_const); + } + + loco::replace(node).with(new_transpose); + new_transpose->a(node); + + // Do shape inference for this node again. + node->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; + } + +// TODO Remove unused code +#if 0 if (auto x = dynamic_cast(node->x())) { if (auto y = dynamic_cast(node->y())) @@ -213,6 +298,7 @@ class EBOWithConstPattern final : public CircleNodeMutableVisitor return true; } } +#endif return false; } diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp index 2d061c2a372..c3c502c98f0 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp @@ -39,13 +39,12 @@ template class TransposeBinaryOpGraphlet virtual ~TransposeBinaryOpGraphlet() = default; public: + // TODO Rename shape_in to shape_const void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 perm) { std::vector shape_in_v = shape_in; std::vector perm_v = perm; - assert(shape_in_v.size() == perm_v.size()); // FIX_CALLER_UNLESS - _perm = g->nodes()->create(); _const = g->nodes()->create(); _transpose = g->nodes()->create(); @@ -69,7 +68,7 @@ template class TransposeBinaryOpGraphlet _perm->at(i) = perm_v[i]; uint32_t elems = 1; - for (uint32_t i = 0; i < size; i++) + for (uint32_t i = 0; i < shape_in_v.size(); i++) elems *= shape_in_v[i]; _const->size(elems); @@ -155,6 +154,42 @@ class ForwardTransposeToMulGraph : public TestIOGraph, public TransposeMulGraphl } }; +class ForwardTransposeToScalarMulGraph : public TestIOGraph, public TransposeMulGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeMulGraphlet::init(g(), {}, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_const); + + output()->from(_binary); + } +}; + +class ForwardTransposeToSingleElemMulGraph : public TestIOGraph, public TransposeMulGraphlet +{ +public: + void init(const ShapeU32 shape_in, const ShapeU32 shape_out) + { + TestIOGraph::init(shape_in, shape_out); + TransposeMulGraphlet::init(g(), {1}, shape_out); + + // connect network + _transpose->a(input()); + _transpose->perm(_perm); + _binary->x(_transpose); + _binary->y(_const); + + output()->from(_binary); + } +}; + void run_phase(loco::Graph *g) { logo::Phase phase; @@ -196,6 +231,24 @@ class ForwardTransposeToMulGraphTest : public ::testing::Test ForwardTransposeToMulGraph _graph; }; +class ForwardTransposeToScalarMulGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToScalarMulGraph _graph; +}; + +class ForwardTransposeToSingleElemMulGraphTest : public ::testing::Test +{ +public: + void run_pass(void) { run_phase(_graph.g()); } + +protected: + ForwardTransposeToSingleElemMulGraph _graph; +}; + } // namespace TEST_F(ForwardTransposeToAddGraphTest, forward_add_xy) @@ -324,6 +377,61 @@ TEST_F(ForwardTransposeToMulGraphTest, forward_mul_yx) EXPECT_EQ(1, mul_const->dim(3).value()); } +TEST_F(ForwardTransposeToScalarMulGraphTest, forward_scalar_mul) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); + + auto mul_const = dynamic_cast(mul->y()); + EXPECT_NE(nullptr, mul_const); + EXPECT_EQ(0, mul_const->rank()); +} + +TEST_F(ForwardTransposeToSingleElemMulGraphTest, forward_single_elem_mul) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + run_pass(); + + auto transpose = dynamic_cast(_graph.output()->from()); + EXPECT_NE(nullptr, transpose); + EXPECT_EQ(4, transpose->rank()); + EXPECT_EQ(1, transpose->dim(0).value()); + EXPECT_EQ(1, transpose->dim(1).value()); + EXPECT_EQ(51, transpose->dim(2).value()); + EXPECT_EQ(64, transpose->dim(3).value()); + + auto mul = dynamic_cast(transpose->a()); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(4, mul->rank()); + EXPECT_EQ(1, mul->dim(0).value()); + EXPECT_EQ(64, mul->dim(1).value()); + EXPECT_EQ(51, mul->dim(2).value()); + EXPECT_EQ(1, mul->dim(3).value()); + + auto mul_const = dynamic_cast(mul->y()); + EXPECT_NE(nullptr, mul_const); + EXPECT_EQ(1, mul_const->rank()); + EXPECT_EQ(1, mul_const->dim(0).value()); +} + TEST_F(ForwardTransposeToAddGraphTest, forward_transpose_add_NEG) { _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); @@ -522,3 +630,25 @@ TEST_F(ForwardTransposeToAbsGraphNegTest, forward_transpose_abs_non_transpose_NE luci::ForwardTransposeOpPass pass; EXPECT_FALSE(pass.run(_graph.g())); } + +TEST_F(ForwardTransposeToScalarMulGraphTest, forward_transpose_smul_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + // Remove mul + _graph.output()->from(_graph.transpose()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +} + +TEST_F(ForwardTransposeToSingleElemMulGraphTest, forward_transpose_se_mul_NEG) +{ + _graph.init({1, 64, 51, 1}, {0, 3, 2, 1}); + + // Remove mul + _graph.output()->from(_graph.transpose()); + + luci::ForwardTransposeOpPass pass; + EXPECT_FALSE(pass.run(_graph.g())); +}