Skip to content

Commit

Permalink
[luci-interpreter] Add check for Reshape (#14446)
Browse files Browse the repository at this point in the history
This adds a check for num_elements.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Dec 13, 2024
1 parent 5316f4f commit 8869179
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
3 changes: 2 additions & 1 deletion compiler/luci-interpreter/src/kernels/Reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ static void resolveUnknownDimension(const Shape &input_shape, Shape *output_shap
output_shape->dim(unknown_dim_index) = num_input_elements / num_output_elements;
num_output_elements *= output_shape->dim(unknown_dim_index);
}
assert(num_output_elements == num_input_elements);

LUCI_INTERPRETER_CHECK(num_output_elements == num_input_elements);
}

Reshape::Reshape(const Tensor *input, const Tensor *shape, Tensor *output)
Expand Down
16 changes: 16 additions & 0 deletions compiler/luci-interpreter/src/kernels/Reshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ TEST_F(ReshapeTest, SupportS16_NEG)
EXPECT_ANY_THROW(kernel.configure());
}

TEST_F(ReshapeTest, NumElementsMismatch_NEG)
{
Shape input_shape{1, 2, 3};
std::vector<float> input_data{1, 2, 3, 4, 5, 6};
Shape shape_shape{2};
std::vector<int32_t> shape_data{1, 7};
Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
Tensor shape_tensor =
makeInputTensor<DataType::S32>(shape_shape, shape_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);

Reshape kernel(&input_tensor, &shape_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}

} // namespace
} // namespace kernels
} // namespace luci_interpreter

0 comments on commit 8869179

Please sign in to comment.