Skip to content

Commit

Permalink
[luci/service] Support RoPE operation (#14091)
Browse files Browse the repository at this point in the history
This commit supports RoPE for luci service

ONE-DCO-1.0-Signed-off-by: youngsik kim [email protected]
  • Loading branch information
ys44kim authored Sep 27, 2024
1 parent dffeb51 commit eb72364
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final;
// loco::TensorShape visit(const luci::CircleBCQGather *node) final;
// loco::TensorShape visit(const luci::CircleInstanceNorm *node) final;
// loco::TensorShape visit(const luci::CircleRoPE *node) final;

// Virtual
// loco::TensorShape visit(const luci::CircleCustomOut *node) final;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::DataType>
// loco::DataType visit(const luci::CircleBCQFullyConnected *node) final;
// loco::DataType visit(const luci::CircleBCQGather *node) final;
// loco::DataType visit(const luci::CircleInstanceNorm *node) final;
// loco::DataType visit(const luci::CircleRoPE *node) final;

// Virtual
// loco::DataType visit(const luci::CircleInput *node) final;
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/service/src/CircleCloneNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
luci::CircleNode *visit(const luci::CircleGRU *) final;
luci::CircleNode *visit(const luci::CircleRmsNorm *) final;
luci::CircleNode *visit(const luci::CircleRoPE *) final;

// NOTE CircleInput and CircleOutput are not handled here as these need
// link with graph I/O
Expand Down
7 changes: 7 additions & 0 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2205,6 +2205,13 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS
return loco::NodeShape{input_shape};
}

loco::NodeShape visit(const luci::CircleRoPE *node) final
{
auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();

return loco::NodeShape{input_shape};
}

// Virtual
loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }

Expand Down
5 changes: 5 additions & 0 deletions compiler/luci/service/src/CircleTypeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return luci::dtype_get(node->input());
}

loco::DataType visit(const luci::CircleRoPE *node) final
{
return luci::dtype_get(node->input());
}

// Virtual
loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }

Expand Down
32 changes: 32 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRoPE.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "CircleCloneNode.h"

namespace luci
{

luci::CircleNode *CloneNode::visit(const luci::CircleRoPE *node)
{
if (node->mode() == luci::RoPEMode::UNDEFINED)
return nullptr;

auto *cloned = _graph->nodes()->create<luci::CircleRoPE>();

return cloned;
}

} // namespace luci
46 changes: 46 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRoPE.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Service/CircleNodeClone.h"

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_RoPE)
{
auto g = loco::make_graph();
auto node_rp = g->nodes()->create<luci::CircleRoPE>();
node_rp->mode(luci::RoPEMode::GPT_NEOX);

auto gc = loco::make_graph();
auto cloned = luci::clone_node(node_rp, gc.get());
ASSERT_NE(nullptr, cloned);
ASSERT_EQ(gc.get(), cloned->graph());

auto cloned_rp = dynamic_cast<luci::CircleRoPE *>(cloned);
ASSERT_NE(nullptr, cloned_rp);
ASSERT_EQ(node_rp->mode(), cloned_rp->mode());
}

TEST(CloneNodeTest, clone_RoPE_NEG)
{
auto g = loco::make_graph();
auto node_rp = g->nodes()->create<luci::CircleRoPE>();
node_rp->mode(luci::RoPEMode::UNDEFINED);

auto gc = loco::make_graph();
auto cloned = luci::clone_node(node_rp, gc.get());
ASSERT_EQ(nullptr, cloned);
}

0 comments on commit eb72364

Please sign in to comment.