From 886917941ed3c5826e932c879f4b18ca23c01f26 Mon Sep 17 00:00:00 2001 From: Hyukjin Jeong Date: Fri, 13 Dec 2024 11:27:42 +0900 Subject: [PATCH] [luci-interpreter] Add check for Reshape (#14446) This adds a check for num_elements. ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong --- .../luci-interpreter/src/kernels/Reshape.cpp | 3 ++- .../src/kernels/Reshape.test.cpp | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/compiler/luci-interpreter/src/kernels/Reshape.cpp b/compiler/luci-interpreter/src/kernels/Reshape.cpp index d3234e48323..1e2a9d6eb0a 100644 --- a/compiler/luci-interpreter/src/kernels/Reshape.cpp +++ b/compiler/luci-interpreter/src/kernels/Reshape.cpp @@ -77,7 +77,8 @@ static void resolveUnknownDimension(const Shape &input_shape, Shape *output_shap output_shape->dim(unknown_dim_index) = num_input_elements / num_output_elements; num_output_elements *= output_shape->dim(unknown_dim_index); } - assert(num_output_elements == num_input_elements); + + LUCI_INTERPRETER_CHECK(num_output_elements == num_input_elements); } Reshape::Reshape(const Tensor *input, const Tensor *shape, Tensor *output) diff --git a/compiler/luci-interpreter/src/kernels/Reshape.test.cpp b/compiler/luci-interpreter/src/kernels/Reshape.test.cpp index 7c0522ebef0..14d9bfdafdb 100644 --- a/compiler/luci-interpreter/src/kernels/Reshape.test.cpp +++ b/compiler/luci-interpreter/src/kernels/Reshape.test.cpp @@ -113,6 +113,22 @@ TEST_F(ReshapeTest, SupportS16_NEG) EXPECT_ANY_THROW(kernel.configure()); } +TEST_F(ReshapeTest, NumElementsMismatch_NEG) +{ + Shape input_shape{1, 2, 3}; + std::vector input_data{1, 2, 3, 4, 5, 6}; + Shape shape_shape{2}; + std::vector shape_data{1, 7}; + Tensor input_tensor = + makeInputTensor(input_shape, input_data, _memory_manager.get()); + Tensor shape_tensor = + makeInputTensor(shape_shape, shape_data, _memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); + + Reshape kernel(&input_tensor, &shape_tensor, &output_tensor); + EXPECT_ANY_THROW(kernel.configure()); +} + } // namespace } // namespace kernels } // namespace luci_interpreter