Skip to content

Commit

Permalink
[luci] Fix ForwardRehsapeToUnaryOpPass bug (#14474)
Browse files Browse the repository at this point in the history
This updates rank of constant after reshape is forwarded.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Dec 18, 2024
1 parent 5044bb4 commit 17081b4
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
13 changes: 13 additions & 0 deletions compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,29 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleMul *div,
if (not new_reshape)
return false;

const auto prev = loco::must_cast<luci::CircleNode *>(reshape->tensor());

// Reshape can change rank of tensor, so we need to update constant value accordingly.
assert(const_value->size<loco::DataType::FLOAT32>() == 1);
auto cloned_const = clone(const_value);
cloned_const->rank(prev->rank());
for (uint32_t i = 0; i < prev->rank(); ++i)
{
cloned_const->dim(i).set(1);
}

// reconnect network
loco::replace(div).with(new_reshape);
if (div->x() == const_value)
{
div->x(cloned_const);
div->y(reshape->tensor());
}
else
{
assert(div->y() == const_value);
div->x(reshape->tensor());
div->y(cloned_const);
}
new_reshape->tensor(div);

Expand Down
106 changes: 106 additions & 0 deletions compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,56 @@ class ReshapeLogisticGraphlet
luci::CircleConst *_reshape_shape = nullptr;
};

class ReshapeMulGraphlet
{
public:
ReshapeMulGraphlet() = default;

public:
void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out)
{
std::vector<uint32_t> shape_out_v = shape_out;

_reshape_shape = g->nodes()->create<luci::CircleConst>();
_reshape = g->nodes()->create<luci::CircleReshape>();
_mul = g->nodes()->create<luci::CircleMul>();
_const = g->nodes()->create<luci::CircleConst>();

_reshape_shape->dtype(loco::DataType::S32);
_reshape_shape->rank(1);
_reshape_shape->dim(0).set(shape_out_v.size());
_reshape_shape->shape_status(luci::ShapeStatus::VALID);

// values
const auto size = shape_out_v.size();
_reshape_shape->size<loco::DataType::S32>(size);
for (uint32_t i = 0; i < size; i++)
_reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i];

_const->dtype(loco::DataType::FLOAT32);
_const->rank(size);
uint32_t numel = 1;
for (uint32_t i = 0; i < size; i++)
{
_const->dim(i).set(1);
}
_const->size<loco::DataType::FLOAT32>(1);
_const->at<loco::DataType::FLOAT32>(0) = 1.0;
_const->shape_status(luci::ShapeStatus::VALID);

_reshape_shape->name("reshape_shape");
_reshape->name("reshape");
_mul->name("mul");
_const->name("const");
}

protected:
luci::CircleReshape *_reshape = nullptr;
luci::CircleMul *_mul = nullptr;
luci::CircleConst *_const = nullptr;
luci::CircleConst *_reshape_shape = nullptr;
};

class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet
{
public:
Expand Down Expand Up @@ -141,6 +191,27 @@ class ForwardReshapeToLogisticGraph : public TestIOGraph, public ReshapeLogistic
}
};

class ForwardReshapeToMulGraph : public TestIOGraph, public ReshapeMulGraphlet
{
public:
ForwardReshapeToMulGraph() = default;

public:
void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
{
TestIOGraph::init(shape_in, shape_out);
ReshapeMulGraphlet::init(g(), shape_in, shape_out);

// connect network
_reshape->tensor(input());
_reshape->shape(_reshape_shape);
_mul->x(_reshape);
_mul->y(_const);

output()->from(_mul);
}
};

class ForwardReshapeToNegGraphTest : public ::testing::Test
{
public:
Expand Down Expand Up @@ -173,6 +244,22 @@ class ForwardReshapeToLogisticGraphTest : public ::testing::Test
luci::ForwardReshapeToUnaryOpPass _pass;
};

class ForwardReshapeToMulGraphTest : public ::testing::Test
{
public:
ForwardReshapeToMulGraphTest() = default;

void run_pass(void)
{
while (_pass.run(_graph.g()))
;
}

protected:
ForwardReshapeToMulGraph _graph;
luci::ForwardReshapeToUnaryOpPass _pass;
};

/**
* Simple graph for test
*
Expand Down Expand Up @@ -318,6 +405,25 @@ TEST_F(ForwardReshapeToLogisticGraphTest, forward)
ASSERT_NE(nullptr, log);
}

TEST_F(ForwardReshapeToMulGraphTest, forward_rank_update)
{
_graph.init({1, 2, 3}, {1, 1, 2, 3});

run_pass();

auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
ASSERT_NE(nullptr, reshape);

auto mul = dynamic_cast<luci::CircleMul *>(reshape->tensor());
ASSERT_NE(nullptr, mul);

// Check mul's const rank == input rank (3)
auto const_mul = dynamic_cast<luci::CircleConst *>(mul->y());
ASSERT_NE(nullptr, const_mul);

ASSERT_EQ(3, const_mul->rank());
}

TEST(FuseMulWithDivPassTest, forward_reshape_to_mean_pattern)
{
ForwardReshapeToMeanPatternTestGraph g;
Expand Down

0 comments on commit 17081b4

Please sign in to comment.