Skip to content

Commit

Permalink
[luci/pass] Introduce SubstituteExpandDimsToReshapePass
Browse files Browse the repository at this point in the history
Let's introduces a new luci pass to substitute CircleExpandDims to CircleReshape.

Signed-off-by: Dayoung Lee <[email protected]>
  • Loading branch information
dayo09 committed Oct 4, 2024
1 parent 8753418 commit b093212
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_SUBSTITUTE_EXPAND_DIMS_TO_RESHAPE_PASS_H__
#define __LUCI_SUBSTITUTE_EXPAND_DIMS_TO_RESHAPE_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to substitute ExpandDims to single reshape node.
*/
struct SubstituteExpandDimsToReshapePass final : public logo::Pass
{
const char *name(void) const final { return "luci::SubstituteExpandDimsToReshapePass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_SUBSTITUTE_EXPAND_DIMS_TO_RESHAPE_PASS_H__
145 changes: 145 additions & 0 deletions compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/SubstituteExpandDimsToReshapePass.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <bitset>
#include <vector>

/**
* @brief Convert expand_dims op to reshape op
* @example
* input.shape = [2,3,4]
* expand_dims(input, axis=1)
*
* can be converted to
*
* reshape(input, [2,1,3,4])
*/
namespace
{

int32_t unknown_dim_count(luci::CircleNode *node)
{
int32_t count = 0;

for (uint32_t i = 0; i < node->rank(); ++i)
if (!node->dim(i).known())
++count;

return count;
}

bool substitute_expand_dims_to_reshape(luci::CircleNode *node)
{
auto target_node = dynamic_cast<luci::CircleExpandDims *>(node);
if (target_node == nullptr)
return false;
if (target_node->shape_status() != luci::ShapeStatus::VALID) //
return false;
auto input_node = loco::must_cast<luci::CircleNode *>(target_node->input());
if (input_node->rank() <= 0)
return false;
if (input_node->shape_status() != luci::ShapeStatus::VALID) //
return false;
auto axis_node = loco::must_cast<luci::CircleConst *>(target_node->axis());
if (axis_node == nullptr)
return false;

auto axis = axis_node->at<loco::DataType::S32>(0);
if (axis < 0)
axis = axis + static_cast<int32_t>(input_node->rank()) + 1;

auto name = node->name();
assert(name.length() > 0);

auto graph = target_node->graph();
auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
reshape_node->tensor(input_node);
reshape_node->name(name + "/Reshape");
luci::add_origin(reshape_node, luci::get_origin(node));

auto const_node = graph->nodes()->create<luci::CircleConst>();
const_node->dtype(loco::DataType::S32);
const_node->size<loco::DataType::S32>(input_node->rank() + 1);
const_node->shape_status(luci::ShapeStatus::VALID);
const_node->rank(1);
const_node->dim(0).set(input_node->rank() + 1);
for (int32_t i = 0; i < static_cast<int32_t>(input_node->rank()) + 1; i++)
{
if (i == axis)
{
const_node->at<loco::DataType::S32>(i) = 1;
}
else if (i < axis)
{
const_node->at<loco::DataType::S32>(i) =
input_node->dim(i).known() ? input_node->dim(i).value() : -1;
}
else
{
const_node->at<loco::DataType::S32>(i) =
input_node->dim(i - 1).known() ? input_node->dim(i - 1).value() : -1;
}
}
const_node->name(name + "/Reshape/shape");
reshape_node->shape(const_node);
replace(target_node).with(reshape_node);
return true;
}

} // namespace

namespace luci
{

/**
* BEFORE
* |
* [CircleNode] [CircleConst]
* \ /
* [CircleExpandDims]
* |
* [CircleNode]
* |
*
* AFTER
* |
* [CircleNode] [CircleConst]
* \ /
* [CircleReshape]
* |
* [CircleNode]
* |
*/
bool SubstituteExpandDimsToReshapePass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
if (unknown_dim_count(circle_node) == 0 && substitute_expand_dims_to_reshape(circle_node))
{
changed = true;
}
}
return changed;
}

} // namespace luci
201 changes: 201 additions & 0 deletions compiler/luci/pass/src/SubstituteExpandDimsToReshapePass.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "luci/Pass/SubstituteExpandDimsToReshapePass.h"
#include "luci/Pass/CircleShapeInferencePass.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

namespace
{

using uilist = std::initializer_list<uint32_t>;
using ilist = std::initializer_list<int32_t>;

class PassTestGraph
{
public:
PassTestGraph() = default;

public:
void init(const uilist shape_in, const uilist shape_out, const int val)
{
_graph_input = _g.inputs()->create();
_graph_output = _g.outputs()->create();

_input = _g.nodes()->create<luci::CircleInput>();
_input->shape(shape_in);
_input->shape_status(luci::ShapeStatus::VALID);
_input->name("input");

_output = _g.nodes()->create<luci::CircleOutput>();
_output->shape(shape_out);
_output->shape_status(luci::ShapeStatus::VALID);
_output->name("output");

_const = _g.nodes()->create<luci::CircleConst>();
_const->dtype(loco::DataType::S32);
_const->size<loco::DataType::S32>(1);
_const->at<loco::DataType::S32>(0) = val;
_const->name("const");

_input->index(_graph_input->index());
_output->index(_graph_output->index());

auto input_shape = std::make_unique<loco::TensorShape>();
set(input_shape.get(), shape_in);
_graph_input->shape(std::move(input_shape));

auto output_shape = std::make_unique<loco::TensorShape>();
set(output_shape.get(), shape_out);
_graph_output->shape(std::move(output_shape));
}

protected:
void set(loco::TensorShape *shape, const uilist &values)
{
uint32_t r = 0;
shape->rank(values.size());
for (auto v : values)
shape->dim(r++).set(v);
}

public:
loco::Graph *g(void) { return &_g; }
luci::CircleOutput *output(void) { return _output; }

protected:
loco::Graph _g;
loco::GraphInput *_graph_input = nullptr;
loco::GraphOutput *_graph_output = nullptr;
luci::CircleInput *_input = nullptr;
luci::CircleOutput *_output = nullptr;
luci::CircleConst *_const = nullptr;
};

class SubstituteExpandDimsToReshapeGraph : public PassTestGraph
{
public:
SubstituteExpandDimsToReshapeGraph() = default;

public:
void init(const uilist shape_in, const uilist shape_out, const int axis)
{
PassTestGraph::init(shape_in, shape_out, axis);

_expand_dims = _g.nodes()->create<luci::CircleExpandDims>();
_expand_dims->input(_input);
_expand_dims->axis(_const);
_expand_dims->name("expand_dims");

_output->from(_expand_dims);
}

protected:
luci::CircleExpandDims *_expand_dims = nullptr;
};

class SubstituteExpandDimsToReshapeTest : public ::testing::Test
{
public:
SubstituteExpandDimsToReshapeTest() = default;

void run_pass(void)
{
while (_shapeinf.run(_graph.g()) || _pass.run(_graph.g()))
;
}

protected:
SubstituteExpandDimsToReshapeGraph _graph;
luci::SubstituteExpandDimsToReshapePass _pass;
luci::CircleShapeInferencePass _shapeinf;
};

} // namespace

TEST(SubstituteExpandDimsToReshapePassTest, name)
{
luci::SubstituteExpandDimsToReshapePass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(SubstituteExpandDimsToReshapeTest, simple_with_expand_dims_1)
{
_graph.init({2, 16}, {2, 1, 16}, 1);

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
auto expand_dims = dynamic_cast<luci::CircleExpandDims *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);
ASSERT_EQ(nullptr, expand_dims);
auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
ASSERT_EQ(3, reshape_shape->size<loco::DataType::S32>());
ASSERT_EQ(2, reshape_shape->at<loco::DataType::S32>(0));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(1));
ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(2));
}

TEST_F(SubstituteExpandDimsToReshapeTest, simple_with_expand_dims_M1)
{
_graph.init({2, 3, 4}, {2, 3, 4, 1}, -1);

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
auto expand_dims = dynamic_cast<luci::CircleExpandDims *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);
ASSERT_EQ(nullptr, expand_dims);
auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
ASSERT_EQ(4, reshape_shape->size<loco::DataType::S32>());
ASSERT_EQ(2, reshape_shape->at<loco::DataType::S32>(0));
ASSERT_EQ(3, reshape_shape->at<loco::DataType::S32>(1));
ASSERT_EQ(4, reshape_shape->at<loco::DataType::S32>(2));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(3));
}

TEST_F(SubstituteExpandDimsToReshapeTest, simple_with_expand_dims_2)
{
_graph.init({16, 3, 1}, {16, 3, 1, 1}, 2);

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
auto expand_dims = dynamic_cast<luci::CircleExpandDims *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);
ASSERT_EQ(nullptr, expand_dims);
auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
ASSERT_EQ(4, reshape_shape->size<loco::DataType::S32>());
ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0));
ASSERT_EQ(3, reshape_shape->at<loco::DataType::S32>(1));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(2));
ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(3));
}

TEST_F(SubstituteExpandDimsToReshapeTest, nothing_to_expand_dims)
{
_graph.init({2, 16, 16, 3}, {2, 16, 16, 3}, {});

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
auto expand_dims = dynamic_cast<luci::CircleExpandDims *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);
ASSERT_EQ(nullptr, expand_dims);
}

0 comments on commit b093212

Please sign in to comment.