Skip to content

Commit

Permalink
[luci/pass] Support rank expansion case in RmUnnTransNetPass (#14167)
Browse files Browse the repository at this point in the history
This PR expands RemoveUnnecessaryTransposeNetPass to support rank expansion case.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>
Co-authored-by: Jang Jiseob <[email protected]>
  • Loading branch information
zetwhite and ragmani authored Oct 9, 2024
1 parent 9adb622 commit fd49fc8
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 12 deletions.
69 changes: 57 additions & 12 deletions compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ bool extract_const(const luci::CircleConst *const_node, std::vector<int32_t> &va
return true;
};

/**
* @brief For a given a perm P, this function returns Q (the inverse of P),
* which satisfies 'Q(P(x))==x'
*
* @example If perm={1, 3, 0, 2} is given, it returns {2, 0, 3, 1}.
*
* x: [0, 1, 2, 3]
* | -----> apply perm {1, 3, 0, 2}
* P(x): [1, 3, 0, 2]
* | -----> apply perm {2, 0, 3, 1}
* Q(P(x)) [0, 1, 2, 3]
*
*/
std::vector<int32_t> inverse_perm(const std::vector<int32_t> &perm)
{
std::vector<int32_t> inv_perm(perm.size());
for (auto i = 0u; i < perm.size(); ++i)
{
inv_perm[perm[i]] = i;
}
return inv_perm;
};

struct TagDim final
{
TagDim(int32_t v) : value(v) {}
Expand Down Expand Up @@ -106,8 +129,10 @@ class TaggedShapeAnalyzer final

std::vector<int32_t> _in_shape_v;
std::vector<int32_t> _front_perm_v;
std::vector<int32_t> _front_shape_v;
std::vector<int32_t> _mid_shape_v;
std::vector<int32_t> _back_perm_v;
std::vector<int32_t> _out_shape_v;

const uint8_t START_TAG = 0;
TagShape _shape;
Expand Down Expand Up @@ -268,9 +293,8 @@ bool TaggedShapeAnalyzer::verify_tag() const
* @brief Initialize the class members and check under conditions
*
* Condtiions that have to be met for analyzer
* c1: input rank >= output rank
* c2: The 'perm' of tranpose should be a CircleConst* type
* c3: The input shape and the reshape node's shape should be known
* c3: All extracted shapes (named as '*_shape_v' in member variable) should be known
*
* @return True, if all conditions are satisfied and class members are initialized successfully
* False, otherwise
Expand All @@ -284,9 +308,6 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,
_mid_reshape = mid_reshape;
_back_transpose = back_transpose;

// check c1
CHECK_OR_FALSE(_in->rank() >= _back_transpose->rank());

const auto front_perm = dynamic_cast<luci::CircleConst *>(_front_transpose->perm());
const auto back_perm = dynamic_cast<luci::CircleConst *>(_back_transpose->perm());

Expand All @@ -296,8 +317,10 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,

CHECK_OR_FALSE(extract_shape(_in, _in_shape_v));
CHECK_OR_FALSE(extract_const(front_perm, _front_perm_v));
CHECK_OR_FALSE(extract_shape(_front_transpose, _front_shape_v));
CHECK_OR_FALSE(extract_shape(_mid_reshape, _mid_shape_v));
CHECK_OR_FALSE(extract_const(back_perm, _back_perm_v));
CHECK_OR_FALSE(extract_shape(_back_transpose, _out_shape_v));

auto all_known = [](const std::vector<int32_t> &v) -> bool {
for (auto i : v)
Expand All @@ -308,7 +331,9 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,

// check c3
CHECK_OR_FALSE(all_known(_in_shape_v));
CHECK_OR_FALSE(all_known(_front_shape_v));
CHECK_OR_FALSE(all_known(_mid_shape_v));
CHECK_OR_FALSE(all_known(_out_shape_v));

return true;
}
Expand Down Expand Up @@ -360,14 +385,33 @@ bool TaggedShapeAnalyzer::can_remove_transposes()
assert(_in != nullptr && _front_transpose != nullptr && _mid_reshape != nullptr &&
_back_transpose != nullptr);

init_shape_with_tag(_in_shape_v);

analyze_transpose(_front_perm_v);
auto count_not_1_dim = [](const std::vector<int32_t> &shape) -> int {
int count = 0;
for (auto i : shape)
{
if (i != 1)
count++;
}
return count;
};

if (not analyze_reshape(_mid_shape_v))
return false;
auto in_rank = count_not_1_dim(_in_shape_v);
auto out_rank = count_not_1_dim(_out_shape_v);

analyze_transpose(_back_perm_v);
if (in_rank >= out_rank)
{
init_shape_with_tag(_in_shape_v);
analyze_transpose(_front_perm_v);
CHECK_OR_FALSE(analyze_reshape(_mid_shape_v));
analyze_transpose(_back_perm_v);
}
else
{
init_shape_with_tag(_out_shape_v);
analyze_transpose(inverse_perm(_back_perm_v));
CHECK_OR_FALSE(analyze_reshape(_front_shape_v));
analyze_transpose(inverse_perm(_front_perm_v));
}

if (not verify_tag())
return false;
Expand Down Expand Up @@ -460,8 +504,9 @@ namespace luci
* BEFORE
*
* Current pass only targets below cases:
* - in.rank() >= out.rank()
* - 'Reshape' used to reduce N dimension into one (e.g. A x B x C => A x BC)
* or
* - 'Reshape' devides a single dimension into consecutive N dimensions. (e.g. ABC => A x B x C)
*
*
* [CircleNode] [CircleConst]
Expand Down
120 changes: 120 additions & 0 deletions compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,123 @@ TEST(RemoveUnnecessaryTransposeNetPass, incomplete_reshape_pattern2_NEG)
EXPECT_FALSE(pass.run(g.g()));
EXPECT_FALSE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, rank_expansion_pattern1)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (1, 16384, 512)
* |
* (1, 512, 16384)
* |
* (1, 512, 128, 128)
* |
* (1, 128, 128, 512)
*/
g.init_whole_graph(/*in*/ {1, 16384, 512}, /*perm*/ {0, 2, 1}, /*reshape*/ {1, 512, 128, 128},
/*perm*/ {0, 2, 3, 1}, /*out*/ {1, 128, 128, 512});
EXPECT_TRUE(pass.run(g.g()));
EXPECT_TRUE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, rank_expansion_pattern2)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (1, 1, 10, 1)
* |
* (1, 10, 1, 1)
* |
* (2, 1, 5, 1)
* |
* (2, 5, 1, 1)
*/
g.init_whole_graph(/*in*/ {1, 1, 10, 1}, /*perm*/ {0, 3, 1, 2}, /*reshape*/ {2, 1, 5, 1},
/*perm*/ {0, 2, 1, 3}, /*out*/ {2, 5, 1, 1});
EXPECT_TRUE(pass.run(g.g()));
EXPECT_TRUE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, eff_rank_expansion_pattern3)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (2, 1, 12, 1)
* |
* (2, 12, 1, 1)
* |
* (1, 2, 3, 4)
* |
* (2, 1, 3, 4)
*/
g.init_whole_graph(/*in*/ {2, 1, 12, 1}, /*perm*/ {0, 2, 1, 3}, /*reshape*/ {1, 2, 3, 4},
/*perm*/ {1, 0, 2, 3}, /*out*/ {2, 1, 3, 4});
EXPECT_TRUE(pass.run(g.g()));
EXPECT_TRUE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, rank_expansion_pattern1_NEG)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (1, 6, 49)
* |
* (1, 49, 6)
* |
* (1, 7, 6, 7)
* |
* (7, 1, 6, 7)
*/
g.init_whole_graph(/*in*/ {1, 6, 7}, /*perm*/ {0, 2, 1}, /*reshape*/ {1, 7, 6, 7},
/*perm*/ {1, 0, 2, 3}, /*out*/ {7, 1, 6, 7});
EXPECT_FALSE(pass.run(g.g()));
EXPECT_FALSE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, rank_expansion_pattern2_NEG)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (1, 1, 10, 1)
* |
* (10, 1, 1, 1)
* |
* (5, 2, 1, 1)
* |
* (2, 5, 1, 1)
*/
g.init_whole_graph(/*in*/ {1, 1, 10, 1}, /*perm*/ {2, 0, 1, 3}, /*reshape*/ {5, 2, 1, 1},
/*perm*/ {1, 0, 2, 3}, /*out*/ {2, 5, 1, 1});
EXPECT_FALSE(pass.run(g.g()));
EXPECT_FALSE(is_transpose_removed(g.g()));
}

TEST(RemoveUnnecessaryTransposeNetPass, rank_expansion_pattern3_NEG)
{
TransposeReshapeTransposeGraph g;
luci::RemoveUnnecessaryTransposeNetPass pass;

/**
* (7, 10)
* |
* (10, 7)
* |
* (2, 5, 7)
* |
* (2, 7, 5)
*/
g.init_whole_graph(/*in*/ {7, 10}, /*perm*/ {1, 0}, /*reshape*/ {2, 5, 7},
/*perm*/ {0, 2, 1}, /*out*/ {2, 7, 5});
EXPECT_FALSE(pass.run(g.g()));
EXPECT_FALSE(is_transpose_removed(g.g()));
}

0 comments on commit fd49fc8

Please sign in to comment.