diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp index a1ff82f83e2..82bc9c7321f 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -95,6 +96,8 @@ luci::CircleReshape *create_reshape(luci::CircleFullyConnected *node) reshape->shape(shape_const); + luci::copy_quantparam(node, reshape); + return reshape; } @@ -165,9 +168,6 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) x = loco::must_cast(fc->input()); } - if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32) - return false; - auto bc = dynamic_cast(fc->bias()); // NOTE bias can be empty as CircleOutputExclude type // NOTE we can only handle bias as FLOAT32 type as of now @@ -185,6 +185,8 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc) matmul->name(name); matmul->dtype(fc->dtype()); + luci::copy_quantparam(fc, matmul); + luci::add_origin(matmul, luci::get_origin(fc)); auto reshape = create_reshape(fc); diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp index 194893f0124..50fdb25885e 100644 --- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp +++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp @@ -84,6 +84,56 @@ struct FCGraphlet luci::CircleInput *_y = nullptr; }; +struct S16FCGraphlet +{ +public: + S16FCGraphlet() = default; + virtual ~S16FCGraphlet() = default; + + void init(loco::Graph *g, const ShapeU32 r_shape) + { + _tr_x = g->nodes()->create(); + _tr_x->a(_x); + std::vector tr_x_val = {1, 0}; + _tr_x->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_x_val)); + _tr_x->dtype(loco::DataType::S16); + + _tr_y = g->nodes()->create(); + _tr_y->a(_y); + std::vector tr_y_val = {1, 0}; + _tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_y_val)); + _tr_y->dtype(loco::DataType::S16); + + _fc = g->nodes()->create(); + _fc->input(_tr_x); + _fc->weights(_tr_y); + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + _fc->dtype(loco::DataType::S16); + _fc->shape(r_shape); + + auto no_bias = g->nodes()->create(); + _fc->bias(no_bias); + _fc->name("fc"); + + auto qparam = std::make_unique(); + { + qparam->scale = {1.0}; + qparam->zerop = {0}; + } + _fc->quantparam(std::move(qparam)); + } + +public: + luci::CircleFullyConnected *fc() { return _fc; } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleTranspose *_tr_x = nullptr; + luci::CircleTranspose *_tr_y = nullptr; + luci::CircleInput *_x = nullptr; + luci::CircleInput *_y = nullptr; +}; + struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet { FCGraph() = default; @@ -99,6 +149,19 @@ struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphl } }; +struct S16FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public S16FCGraphlet +{ + void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape) + { + TestIsGraphlet<2>::init(g(), {x_shape, y_shape}); + TestOGraphlet::init(g(), r_shape); + _x = input(0); + _y = input(1); + S16FCGraphlet::init(g(), r_shape); + output()->from(_fc); + } +}; + class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test { public: @@ -106,6 +169,13 @@ class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test luci::ReplaceNonConstFCWithBatchMatMulPass pass; }; +class ReplaceNonConstS16FCWithBatchMatMulPassTest : public ::testing::Test +{ +public: + S16FCGraph g; + luci::ReplaceNonConstFCWithBatchMatMulPass pass; +}; + } // namespace TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test) @@ -130,6 +200,22 @@ TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test) EXPECT_NE(nullptr, mm); } +TEST_F(ReplaceNonConstS16FCWithBatchMatMulPassTest, s16_test) +{ + g.init({2, 3}, {2, 3}, {2, 2}); + + auto ret = pass.run(g.g()); + EXPECT_EQ(true, ret); + + auto res = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, res); + + auto qparam = res->quantparam(); + EXPECT_NE(nullptr, qparam); + EXPECT_EQ(1.0, qparam->scale[0]); + EXPECT_EQ(0, qparam->zerop[0]); +} + TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG) { loco::Graph g;