-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT][luci/service] Support dynamic shape for reshape #13935
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,14 @@ | |
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Service/CircleShapeInference.h" | ||
#include "Check.h" | ||
|
||
#include "CircleShapeInferenceHelper.h" | ||
#include "CircleCloneNode.h" | ||
|
||
#include <oops/InternalExn.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
|
@@ -34,4 +40,134 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node) | |
return cloned; | ||
} | ||
|
||
namespace sinf | ||
{ | ||
|
||
/** | ||
* @note CircleReshape always has two inputs: tensor and shape. | ||
* The shape input can be CircleConst, CircleOutputDummy, or CircleNode. | ||
* - If the shape input is CircleConst, the shape is inferred from the constant. | ||
* - If the shape input is CircleOutputDummy, the shape is inferred from | ||
* the attribute if it exists. If the attribute does not exist, | ||
* the shape is inferred from the node iteself. | ||
* - If the shape input is CircleNode, the shape is not inferred. | ||
*/ | ||
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) | ||
zetwhite marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
const loco::DataType S32 = loco::DataType::S32; | ||
|
||
// CircleReshape node must have reshape/shape | ||
if (node->shape() == nullptr) | ||
{ | ||
INTERNAL_EXN("2nd input shape() should not be nullptr"); | ||
} | ||
|
||
bool should_infer = true; | ||
loco::TensorShape output_shape; | ||
{ | ||
// Check if reshape/shape is CircleConst | ||
auto const_input = dynamic_cast<luci::CircleConst *>(node->shape()); | ||
if (const_input != nullptr) | ||
{ | ||
output_shape.rank(const_input->size<S32>()); | ||
|
||
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) | ||
{ | ||
output_shape.dim(axis) = const_input->at<S32>(axis); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed it...! Thanks for pointing really important thing out :) I'll make the change. |
||
if (const_input->at<S32>(axis) < 0) | ||
{ | ||
output_shape.dim(axis).unset(); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
// Check if reshape/shape is CircleOutputDummy | ||
auto dummy_input = dynamic_cast<luci::CircleOutputDummy *>(node->shape()); | ||
if (dummy_input != nullptr) | ||
{ | ||
if (node->newShape()->rank() > 0) | ||
{ | ||
output_shape.rank(node->newShape()->rank()); | ||
|
||
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) | ||
{ | ||
output_shape.dim(axis) = node->newShape()->dim(axis); | ||
if (node->newShape()->dim(axis) < 0) | ||
{ | ||
output_shape.dim(axis).unset(); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
output_shape = circle_shape(node); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, to check my understanding, This part corresponds to 'get shape from own shape`. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, you're right :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I've read the related PRs (#1554, #1519) but I'm not sure how to handle this recipe. Do you think we need to discuss about this recipe on #13927 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to #13927 (comment), we may be able to revise the recipe first...! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @llFreetimell Thank you a lot for helping us 😄 I think @jongwonyang followed what i suggested in here : #13927 (comment)
It was hard to make a policy about "no attribute, no shape input" case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't mean to say it should be done this way. For now, I chose an easy way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understood :D |
||
} | ||
} | ||
else | ||
{ | ||
// Check if reshape/shape is CircleNode | ||
auto node_input = dynamic_cast<luci::CircleNode *>(node->shape()); | ||
if (node_input != nullptr) | ||
{ | ||
output_shape.rank(node_input->dim(0).value()); | ||
|
||
for (uint32_t axis = 0; axis < output_shape.rank(); ++axis) | ||
{ | ||
output_shape.dim(axis).unset(); | ||
} | ||
|
||
should_infer = false; | ||
} | ||
} | ||
} | ||
} | ||
|
||
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor()); | ||
const auto input_shape = circle_shape(input); | ||
uint32_t input_element_count = 1; | ||
for (uint32_t axis = 0; axis < input_shape.rank(); ++axis) | ||
{ | ||
if (input_shape.dim(axis).known()) | ||
{ | ||
input_element_count *= input_shape.dim(axis).value(); | ||
} | ||
else | ||
{ | ||
should_infer = false; | ||
break; | ||
} | ||
} | ||
|
||
if (should_infer) | ||
{ | ||
uint32_t output_element_count = 1; | ||
uint32_t unknown_dim_index = UINT32_MAX; | ||
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) | ||
{ | ||
if (output_shape.dim(dim_index).known() == false) | ||
{ | ||
if (unknown_dim_index != UINT32_MAX) | ||
{ | ||
INTERNAL_EXN("More than one unknown dimension"); | ||
} | ||
unknown_dim_index = dim_index; | ||
} | ||
else | ||
{ | ||
const uint32_t dim_value = output_shape.dim(dim_index).value(); | ||
output_element_count *= dim_value; | ||
} | ||
} | ||
if (unknown_dim_index != UINT32_MAX) | ||
{ | ||
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; | ||
} | ||
} | ||
|
||
return output_shape; | ||
} | ||
|
||
} // namespace sinf | ||
|
||
} // namespace luci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question,
What does the
CircleOutputDummy
mean?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While importing the circle file,
CircleOutputDummy
is added as node's shape input when there is noshape_by_input
and noshape_by_attr
.ONE/compiler/luci/import/src/Nodes/CircleReshape.cpp
Lines 79 to 92 in 487afbd