From b5d7ed0509ae5359253f41aed258c455c902d660 Mon Sep 17 00:00:00 2001 From: Balyshev Artem <43214667+BalyshevArtem@users.noreply.github.com> Date: Mon, 30 Oct 2023 03:22:25 +0300 Subject: [PATCH] [luci/pass] Introduce FuseHorizontalFCLayers pass (#11787) This commit introduces FuseHorizontalFCLayers pass. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../Pass/FuseHorizontalFullyConnectedPass.h | 34 +++ compiler/luci/pass/src/CircleOptimizer.cpp | 5 + .../src/FuseHorizontalFullyConnectedPass.cpp | 218 ++++++++++++++++++ .../FuseHorizontalFullyConnectedPass.test.cpp | 185 +++++++++++++++ 5 files changed, 443 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h create mode 100644 compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp create mode 100644 compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 4733737034d..436301b14ec 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -41,6 +41,7 @@ class CircleOptimizer final FuseBatchNormWithTConv, FuseSliceWithTConv, FuseBCQ, + FuseHorizontalFullyConnected, FuseInstanceNorm, FuseMeanWithMean, FuseTransposeWithMean, diff --git a/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h new file mode 100644 index 00000000000..49729c081e5 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseHorizontalFullyConnectedPass.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ +#define __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ + +#include + +namespace luci +{ + +struct FuseHorizontalFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseHorizontalFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_HORIZONTAL_FULLY_CONNECTED_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 54a51d47c96..da34d00e37b 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -41,6 +41,7 @@ #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" +#include "luci/Pass/FuseHorizontalFullyConnectedPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" #include "luci/Pass/RemoveDuplicateConstPass.h" @@ -306,6 +307,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseHorizontalFullyConnected)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FuseTransposeWithMean)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp new file mode 100644 index 00000000000..3aa37256af8 --- /dev/null +++ b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseHorizontalFullyConnectedPass.h" + +#include +#include +#include + +namespace luci +{ + +namespace +{ + +bool check_type_and_shape_equality(const CircleNode *left, const CircleNode *right) +{ + if (left->dtype() != right->dtype()) + return false; + + if (left->rank() != right->rank()) + return false; + + for (uint32_t i = 0; i < left->rank(); ++i) + { + if (left->dim(i).value() != right->dim(i).value()) + return false; + } + + return true; +} + +// Add right const to left const (left is updated) +template void sum_const_values(CircleConst *left, const CircleConst *right) +{ + assert(check_type_and_shape_equality(left, right)); // FIX CALLER UNLESS + const auto size = left->template size(); + + for (uint32_t i = 0; i < size; ++i) + { + left->template at(i) += right->template at(i); + } +} + +bool fuse_horizontal_fc_nodes(CircleAdd *add_node) +{ + // Let's check left and right FC nodes + auto left_fc_node = dynamic_cast(add_node->x()); + auto right_fc_node = dynamic_cast(add_node->y()); + + if (left_fc_node == nullptr or right_fc_node == nullptr) + return false; + + if (not check_type_and_shape_equality(left_fc_node, right_fc_node)) + return false; + + if (left_fc_node->fusedActivationFunction() != FusedActFunc::NONE) + return false; + + if (right_fc_node->fusedActivationFunction() != FusedActFunc::NONE) + return false; + + // Let's check that FC nodes have the same input + if (left_fc_node->input() != right_fc_node->input()) + return false; + + // Lets check left and right FC weights: type and shape + auto left_fc_weights = dynamic_cast(left_fc_node->weights()); + auto right_fc_weights = dynamic_cast(right_fc_node->weights()); + + if (left_fc_weights == nullptr or right_fc_weights == nullptr) + return false; + + if (not check_type_and_shape_equality(left_fc_weights, right_fc_weights)) + return false; + + // Lets check left and right FC bias: type and shape + auto left_fc_bias = dynamic_cast(left_fc_node->bias()); + auto right_fc_bias = dynamic_cast(right_fc_node->bias()); + + // Support only if both biases are const, or both are non-const + // TODO Support the case that one FC has a const bias and another FC has no bias. + if ((left_fc_bias == nullptr and right_fc_bias != nullptr) or + (left_fc_bias != nullptr and right_fc_bias == nullptr)) + { + return false; + } + + // Both left/right bias are const. Check dtype/shape. + if (left_fc_bias != nullptr and not check_type_and_shape_equality(left_fc_bias, right_fc_bias)) + return false; + + // Both left/right bias are non-const. Check left/right fc has no bias. + if (left_fc_bias == nullptr) + { + auto left_no_bias = dynamic_cast(left_fc_node->bias()); + auto right_no_bias = dynamic_cast(right_fc_node->bias()); + if (not left_no_bias or not right_no_bias) + return false; + } + + if (left_fc_weights->dtype() != loco::DataType::FLOAT32) + return false; + + if (left_fc_bias != nullptr) + { + if (left_fc_bias->dtype() != loco::DataType::FLOAT32) + return false; + } + + // Lets create fused FC weights and bias + auto fused_fc_weights = clone(left_fc_weights); + add_origin(fused_fc_weights, + composite_origin({get_origin(left_fc_weights), get_origin(right_fc_weights)})); + + CircleConst *fused_fc_bias = nullptr; + if (left_fc_bias != nullptr) + { + fused_fc_bias = clone(left_fc_bias); + add_origin(fused_fc_bias, + composite_origin({get_origin(left_fc_bias), get_origin(right_fc_bias)})); + } + + assert(fused_fc_weights->dtype() == loco::DataType::FLOAT32); + sum_const_values(fused_fc_weights, right_fc_weights); + + if (fused_fc_bias != nullptr) + { + assert(fused_fc_bias->dtype() == loco::DataType::FLOAT32); + sum_const_values(fused_fc_bias, right_fc_bias); + } + + // Create fused FC node + auto graph = left_fc_node->graph(); + auto fused_fc_node = graph->nodes()->create(); + fused_fc_node->input(left_fc_node->input()); + fused_fc_node->weights(fused_fc_weights); + if (fused_fc_bias) + { + fused_fc_node->bias(fused_fc_bias); + } + else + { + assert(nullptr != dynamic_cast(left_fc_node->bias())); // FIX ME UNLESS + fused_fc_node->bias(left_fc_node->bias()); + } + + fused_fc_node->fusedActivationFunction(add_node->fusedActivationFunction()); + fused_fc_node->name(left_fc_node->name() + "_" + right_fc_node->name() + "_fused"); + + add_origin(fused_fc_node, composite_origin({get_origin(left_fc_node), get_origin(right_fc_node), + get_origin(add_node)})); + + replace(add_node).with(fused_fc_node); + + return true; +} + +} // namespace + +/** + * @brief Class to fuse horizontal FC layers + * + * Before + * + * +---- [In] ----+ + * | | + * V V + * fc1 (w1, b1) fc2 (w2, b2) + * | | + * | | + * +---> add <----+ + * | + * V + * [Out] + * + * After + * + * [In] + * | + * V + * fc3 (w1+w2, b1+b2) + * | + * V + * [Out] + * + * Shape/dtype of fc1, fc2, and fc3 should be the same. + */ +bool FuseHorizontalFullyConnectedPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto add_node = dynamic_cast(node); + if (not add_node) + continue; + + if (fuse_horizontal_fc_nodes(add_node)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp new file mode 100644 index 00000000000..3dba7f89afd --- /dev/null +++ b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.test.cpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseHorizontalFullyConnectedPass.h" +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +/* + * Before + * + * +---- [In] ----+ + * | | + * V V + * fc1 (w1, b1) fc2 (w2, b2) + * | | + * | | + * +---> add <----+ + * | + * V + * [Out] + * + * After + * + * [In] + * | + * V + * fc3 (w1+w2, b1+b2) + * | + * V + * [Out] + */ +class FuseHorizontalFCLayersTestGraph : public TestIOGraph +{ +public: + FuseHorizontalFCLayersTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 10}, {1, 10}); + + _left_fc = g()->nodes()->create(); + _right_fc = g()->nodes()->create(); + _left_weight = g()->nodes()->create(); + _right_weight = g()->nodes()->create(); + + _left_fc->name("left FC"); + _right_fc->name("right FC"); + _left_weight->name("left weight"); + _right_weight->name("right weight"); + + _left_fc->dtype(loco::DataType::FLOAT32); + _right_fc->dtype(loco::DataType::FLOAT32); + + _left_fc->shape_status(luci::ShapeStatus::VALID); + _right_fc->shape_status(luci::ShapeStatus::VALID); + + _left_fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _right_fc->fusedActivationFunction(luci::FusedActFunc::NONE); + + _left_fc->rank(2); + _right_fc->rank(2); + + _right_fc->dim(0) = 1; + _right_fc->dim(1) = 10; + + _left_fc->dim(0) = 1; + _left_fc->dim(1) = 10; + + _left_weight->rank(2); + _left_weight->dtype(loco::DataType::FLOAT32); + _left_weight->size(5 * 10); + for (uint32_t i = 0; i < 5 * 10; ++i) + { + _left_weight->at(0) = 1.0f; + } + _left_weight->dim(0) = 5; + _left_weight->dim(1) = 10; + _left_weight->shape_status(luci::ShapeStatus::VALID); + + _right_weight->rank(2); + _right_weight->dtype(loco::DataType::FLOAT32); + _right_weight->size(5 * 10); + for (uint32_t i = 0; i < 5 * 10; ++i) + { + _right_weight->at(0) = 2.0f; + } + _right_weight->dim(0) = 5; + _right_weight->dim(1) = 10; + _right_weight->shape_status(luci::ShapeStatus::VALID); + + const auto left_output_exclude = g()->nodes()->create(); + const auto right_output_exclude = g()->nodes()->create(); + + _left_fc->input(input()); + _left_fc->weights(_left_weight); + _left_fc->bias(left_output_exclude); + _right_fc->input(input()); + _right_fc->weights(_right_weight); + _right_fc->bias(right_output_exclude); + + _add = g()->nodes()->create(); + _add->dtype(loco::DataType::FLOAT32); + _add->rank(2); + _add->dim(0) = 1; + _add->dim(1) = 5; + _add->x(_left_fc); + _add->y(_right_fc); + _add->shape_status(luci::ShapeStatus::VALID); + + output()->from(_add); + } + + luci::CircleFullyConnected *get_left_fc() { return _left_fc; } + +private: + luci::CircleFullyConnected *_left_fc = nullptr; + luci::CircleConst *_left_weight = nullptr; + luci::CircleFullyConnected *_right_fc = nullptr; + luci::CircleConst *_right_weight = nullptr; + luci::CircleAdd *_add = nullptr; +}; + +} // namespace + +TEST(FuseHorizontalFCLayersPassTest, name) +{ + luci::FuseHorizontalFullyConnectedPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(FuseHorizontalFCLayersPassTest, fuse_horizontal_nodes) +{ + FuseHorizontalFCLayersTestGraph g; + luci::FuseHorizontalFullyConnectedPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseHorizontalFCLayersPassTest, fuse_horizontal_nodes_NEG) +{ + FuseHorizontalFCLayersTestGraph g; + luci::FuseHorizontalFullyConnectedPass pass; + + g.init(); + + g.get_left_fc()->fusedActivationFunction(luci::FusedActFunc::RELU6); + + EXPECT_FALSE(pass.run(g.g())); +} + +TEST(FuseHorizontalFCLayersPassTest, wrong_dtype_NEG) +{ + FuseHorizontalFCLayersTestGraph g; + luci::FuseHorizontalFullyConnectedPass pass; + + g.init(); + + g.get_left_fc()->dtype(loco::DataType::S32); + + EXPECT_FALSE(pass.run(g.g())); +}