From 704bed9b224f13d79708a09bd6144a499f449e50 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Fri, 5 Apr 2024 18:09:20 +0900 Subject: [PATCH] [luci/pass] Revise FuseMulWithDivPass for Mul (#12849) This will revise FuseMulWithDivPass with new fuse_mul_with_div_to_mul method to fuse to Mul Op. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- compiler/luci/pass/src/FuseMulWithDivPass.cpp | 100 ++++++++++++++++++ .../luci/pass/src/FuseMulWithDivPass.test.cpp | 52 +++++++++ 2 files changed, 152 insertions(+) diff --git a/compiler/luci/pass/src/FuseMulWithDivPass.cpp b/compiler/luci/pass/src/FuseMulWithDivPass.cpp index f69316775c2..4fa72160099 100644 --- a/compiler/luci/pass/src/FuseMulWithDivPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithDivPass.cpp @@ -16,6 +16,8 @@ #include "luci/Pass/FuseMulWithDivPass.h" +#include "helpers/NodeFiller.h" + #include #include @@ -48,6 +50,28 @@ luci::CircleConst *create_div_const_with_new_value(luci::CircleConst *div_const, return new_div_const; } +// Return a new CircleConst with a new value +luci::CircleConst *create_mul_const_with_new_value(luci::CircleConst *mul_const, + luci::CircleConst *div_const, float new_value) +{ + assert(mul_const); // FIX_CALLER_UNLESS + assert(mul_const->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(mul_const->size() == 1); // FIX_CALLER_UNLESS + + auto new_mul_const = mul_const->graph()->nodes()->create(); + new_mul_const->dtype(loco::DataType::FLOAT32); + new_mul_const->rank(0); + new_mul_const->size(1); + new_mul_const->scalar() = new_value; + new_mul_const->shape_status(luci::ShapeStatus::VALID); + new_mul_const->name(mul_const->name() + ";" + div_const->name()); + + luci::add_origin(new_mul_const, luci::composite_origin( + {luci::get_origin(mul_const), luci::get_origin(div_const)})); + + return new_mul_const; +} + /** * Pass to fuse mul(one of the input is const scalar) and * div(numerator is const scalar) as div @@ -117,6 +141,79 @@ bool fuse_mul_with_div(luci::CircleDiv *div) return true; } +/** + * Pass to fuse mul(one of the input is const scalar) and + * div(numerator is const scalar) as mul + * + * BEFORE + * [CircleNode] [Scalar_Mul_Const] + * | | + * [CirlceMul, (x=CircleNode, y=Scalar_Mul_Const)] -------- + * | + * | [Scalar_Div_Const] + * | | + * [CircleDiv, (x=CirlceMul, y=Scalar_Div_Const)] ------ + * | + * [CircleNode] + * + * AFTER + * [CircleNode] + * | [Scalar_new_Mul_Const] + * | | + * [CircleMul, (x=CircleNode, y=Scalar_new_Mul_Const)] ------- + * | + * [CircleNode] + * + * where Scalar_new_Mul_Const = Scalar_Mul_Const / Scalar_Div_Const + * + **/ +bool fuse_mul_with_div_to_mul(luci::CircleDiv *div) +{ + if (div->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + luci::CircleMul *mul = nullptr; + luci::CircleConst *div_const = nullptr; + if (not luci::fill(&mul, &div_const).with_args_of(div)) + return false; + + if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + if (div_const->dtype() != loco::DataType::FLOAT32) + return false; + // TODO support other shape + if (div_const->size() != 1) + return false; + + luci::CircleNode *mul_input = nullptr; + luci::CircleConst *mul_const = nullptr; + if (not luci::fill(&mul_input, &mul_const).with_commutative_args_of(mul)) + return false; + + if (mul_const->dtype() != loco::DataType::FLOAT32) + return false; + // TODO support other shape + if (mul_const->size() != 1) + return false; + + const auto mul_value = mul_const->at(0); + const auto div_value = div_const->at(0); + const auto new_value = mul_value / div_value; + auto new_mul_const = create_mul_const_with_new_value(mul_const, div_const, new_value); + + auto new_mul = div->graph()->nodes()->create(); + new_mul->fusedActivationFunction(luci::FusedActFunc::NONE); + new_mul->x(mul_input); + new_mul->y(new_mul_const); + new_mul->name(mul->name()); + luci::add_origin(new_mul, luci::composite_origin({luci::get_origin(div), luci::get_origin(mul)})); + + replace(div).with(new_mul); + + return true; +} + } // namespace bool FuseMulWithDivPass::run(loco::Graph *g) @@ -130,6 +227,9 @@ bool FuseMulWithDivPass::run(loco::Graph *g) if (fuse_mul_with_div(div)) changed = true; + + if (fuse_mul_with_div_to_mul(div)) + changed = true; } return changed; diff --git a/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp b/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp index 984e7f082be..67ad48e1d7a 100644 --- a/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithDivPass.test.cpp @@ -113,6 +113,26 @@ class FuseMulDivPatternTestGraph : public TestIOGraph, public PatternMulDivGraph } }; +class FuseMulDivToMulPatternTestGraph : public TestIOGraph, public PatternMulDivGraphlet +{ +public: + FuseMulDivToMulPatternTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 2, 3}, {1, 2, 3}); + PatternMulDivGraphlet::init(g()); + + _mul->x(input()); + _mul->y(_mul_const); + + _div->x(_mul); + _div->y(_div_const); + + output()->from(_div); + } +}; + } // namespace TEST(FuseMulWithDivPassTest, fus_mul_div_pattern) @@ -140,3 +160,35 @@ TEST(FuseMulWithDivPassTest, fuse_mul_div_NEG) EXPECT_FALSE(pass.run(g.g())); } + +TEST(FuseMulWithDivPassTest, fuse_mul_div_to_mul_pattern) +{ + FuseMulDivToMulPatternTestGraph g; + luci::FuseMulWithDivPass pass; + + g.init(); + + auto div = dynamic_cast(g.output()->from()); + EXPECT_NE(div, nullptr); + + EXPECT_TRUE(pass.run(g.g())); + + auto mul = dynamic_cast(g.output()->from()); + EXPECT_NE(mul, nullptr); +} + +TEST(FuseMulWithDivPassTest, fuse_mul_div_to_mul_NEG) +{ + FuseMulDivToMulPatternTestGraph g; + luci::FuseMulWithDivPass pass; + + g.init(); + + // Add CircleRelu operation between CircleMul and Div operations + auto relu = g.g()->nodes()->create(); + relu->name("relu"); + relu->features(g.mul()); + g.div()->x(relu); + + EXPECT_FALSE(pass.run(g.g())); +}