Skip to content

Commit

Permalink
[luci/pass] Add RmsNorm to QuantizePreChecker pass (Samsung#14039)
Browse files Browse the repository at this point in the history
This commit adds RmsNorm to QuantizePreChecker pass.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim authored Sep 23, 2024
1 parent 00c2e01 commit 8a5da63
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/pass/src/QuantizePreCheckerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct ConstInputChecker final : public luci::CircleNodeMutableVisitor<void>
CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleDepthwiseConv2D, filter, bias)
CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleFullyConnected, weights, bias)
CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleInstanceNorm, gamma, beta)
CHECK_NODE_WITH_TWO_INPUT_CONST(luci::CircleRmsNorm, gamma, beta)

// Ops that receive three const nodes as an inputs
CHECK_NODE_WITH_THREE_INPUT_CONST(luci::CircleTransposeConv, inputSizes, filter, bias)
Expand Down
62 changes: 62 additions & 0 deletions compiler/luci/pass/src/QuantizePreCheckerPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,49 @@ class SimpleInstanceNormGraph
luci::CircleOutput *output = nullptr;
};

class SimpleRmsNormGraph
{
public:
SimpleRmsNormGraph(bool make_valid)
{
rms_norm_node = g.nodes()->create<luci::CircleRmsNorm>();
input_1 = g.nodes()->create<luci::CircleInput>();
gamma = g.nodes()->create<luci::CircleConst>();

rms_norm_node->input(input_1);
rms_norm_node->gamma(gamma);

if (make_valid)
{
beta = g.nodes()->create<luci::CircleConst>();
rms_norm_node->beta(beta);
}
else
{
input_2 = g.nodes()->create<luci::CircleInput>();
rms_norm_node->beta(input_2);
}

output = g.nodes()->create<luci::CircleOutput>();

auto graph_output = g.outputs()->create();
output->index(graph_output->index());

output->from(rms_norm_node);
}

public:
loco::Graph g;

private:
luci::CircleRmsNorm *rms_norm_node = nullptr;
luci::CircleInput *input_1 = nullptr;
luci::CircleInput *input_2 = nullptr;
luci::CircleConst *gamma = nullptr;
luci::CircleConst *beta = nullptr;
luci::CircleOutput *output = nullptr;
};

class SimpleTransposeConvGraph
{
public:
Expand Down Expand Up @@ -363,6 +406,25 @@ TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
}

// Test RmsNorm
TEST(QuantizePreCheckerPassTest, rms_norm)
{
SimpleRmsNormGraph valid_graph(true);

luci::QuantizePreCheckerPass checker{};

EXPECT_NO_THROW(checker.run(&valid_graph.g));
}

TEST(QuantizePreCheckerPassTest, rms_norm_NEG)
{
SimpleRmsNormGraph invalid_graph(false);

luci::QuantizePreCheckerPass checker{};

EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
}

// Test TransposeConv
TEST(QuantizePreCheckerPassTest, transpose_conv)
{
Expand Down

0 comments on commit 8a5da63

Please sign in to comment.