diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp index 4b3b7e33095..1eea4f66d5d 100644 --- a/compiler/luci/pass/src/QuantizePreCheckerPass.cpp +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.cpp @@ -84,6 +84,7 @@ struct ConstInputChecker final : public luci::CircleNodeMutableVisitor CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias) CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias) CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta) + CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleRmsNorm, gamma, beta) // Ops that receive three const nodes as an inputs CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias) diff --git a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp index 8f6a96f3330..3f6295f4a2e 100644 --- a/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp +++ b/compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp @@ -192,6 +192,49 @@ class SimpleInstanceNormGraph luci::CircleOutput *output = nullptr; }; +class SimpleRmsNormGraph +{ +public: + SimpleRmsNormGraph(bool make_valid) + { + rms_norm_node = g.nodes()->create(); + input_1 = g.nodes()->create(); + gamma = g.nodes()->create(); + + rms_norm_node->input(input_1); + rms_norm_node->gamma(gamma); + + if (make_valid) + { + beta = g.nodes()->create(); + rms_norm_node->beta(beta); + } + else + { + input_2 = g.nodes()->create(); + rms_norm_node->beta(input_2); + } + + output = g.nodes()->create(); + + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + output->from(rms_norm_node); + } + +public: + loco::Graph g; + +private: + luci::CircleRmsNorm *rms_norm_node = nullptr; + luci::CircleInput *input_1 = nullptr; + luci::CircleInput *input_2 = nullptr; + luci::CircleConst *gamma = nullptr; + luci::CircleConst *beta = nullptr; + luci::CircleOutput *output = nullptr; +}; + class SimpleTransposeConvGraph { public: @@ -363,6 +406,25 @@ TEST(QuantizePreCheckerPassTest, instance_norm_NEG) EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); } +// Test RmsNorm +TEST(QuantizePreCheckerPassTest, rms_norm) +{ + SimpleRmsNormGraph valid_graph(true); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_NO_THROW(checker.run(&valid_graph.g)); +} + +TEST(QuantizePreCheckerPassTest, rms_norm_NEG) +{ + SimpleRmsNormGraph invalid_graph(false); + + luci::QuantizePreCheckerPass checker{}; + + EXPECT_ANY_THROW(checker.run(&invalid_graph.g)); +} + // Test TransposeConv TEST(QuantizePreCheckerPassTest, transpose_conv) {