Skip to content

Commit

Permalink
[luci/partition] Support RoPE operation (#14101)
Browse files Browse the repository at this point in the history
This commit supports RoPE for luci partition

ONE-DCO-1.0-Signed-off-by: youngsik kim [email protected]
  • Loading branch information
ys44kim authored Sep 27, 2024
1 parent a1e9884 commit 217ae0d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/partition/include/luci/ConnectNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class ConnectNode final : public luci::CircleNodeVisitor<void>
void visit(const luci::CircleGRU *) final;
void visit(const luci::CircleInstanceNorm *) final;
void visit(const luci::CircleRmsNorm *) final;
void visit(const luci::CircleRoPE *) final;

// NOTE CircleInput and CircleOutput are not handled here as these need
// link with graph I/O
Expand Down
42 changes: 42 additions & 0 deletions compiler/luci/partition/src/Nodes/CircleRoPE.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/ConnectNode.h"

namespace
{

void connect(luci::ConnectNode *cn, const luci::CircleRoPE *node)
{
auto *cloned = loco::must_cast<luci::CircleRoPE *>(cn->find_clone(node));

luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
luci::CircleNode *sin_table = loco::must_cast<luci::CircleNode *>(node->sin_table());
luci::CircleNode *cos_table = loco::must_cast<luci::CircleNode *>(node->cos_table());

cloned->input(cn->find_clone(input));
cloned->sin_table(cn->find_clone(sin_table));
cloned->cos_table(cn->find_clone(cos_table));
}

} // namespace

namespace luci
{

void ConnectNode::visit(const luci::CircleRoPE *node) { connect(this, node); }

} // namespace luci
85 changes: 85 additions & 0 deletions compiler/luci/partition/src/Nodes/CircleRoPE.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/ConnectNode.h"

#include "ConnectNode.test.h"

#include <luci/Service/CircleNodeClone.h>

#include <gtest/gtest.h>

namespace
{

using namespace luci::test;

class NodeGraphlet : public NodeGraphletT<luci::CircleRoPE>
{
public:
NodeGraphlet() = default;

public:
void init(loco::Graph *g) override
{
NodeGraphletT<luci::CircleRoPE>::init(g);

_node->mode(luci::RoPEMode::GPT_NEOX);
}
};

class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
{
public:
TestNodeGraph() = default;

public:
void init(const ShapeU32 shape)
{
TestIsOGraph<3>::init({shape, shape, shape}, shape);
NodeGraphlet::init(g());

node()->input(input(0));
node()->sin_table(input(1));
node()->cos_table(input(2));

output()->from(node());
}
};

} // namespace

TEST(ConnectNodeTest, connect_RoPE)
{
TestNodeGraph tng;
tng.init({2, 3});

ConnectionTestHelper cth;
cth.prepare_inputs(&tng);

auto *node = tng.node();
ASSERT_NO_THROW(loco::must_cast<luci::CircleRoPE *>(node));

auto *clone = luci::clone_node(node, cth.graph_clone());
ASSERT_NO_THROW(loco::must_cast<luci::CircleRoPE *>(clone));

cth.clone_connect(node, clone);

ASSERT_EQ(3, clone->arity());
ASSERT_EQ(cth.inputs(0), clone->arg(0));
ASSERT_EQ(cth.inputs(1), clone->arg(1));
ASSERT_EQ(cth.inputs(2), clone->arg(2));
}

0 comments on commit 217ae0d

Please sign in to comment.