Adds check to ensure input tensors match model tensor size & type
PiperOrigin-RevId: 728613394
MediaPipe Team authored and copybara-github committed Feb 19, 2025
1 parent b012383 commit ecbfda1
Showing 3 changed files with 126 additions and 32 deletions.
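In substance: the commit adds a shared TensorDimsAndTypeEqual() helper to inference_calculator_utils that rejects any MediaPipe/TfLite tensor pair whose element type, shape, or total byte size disagree, returning absl::StatusCode::kInvalidArgument before any bytes are copied; CopyTfLiteTensorToTensor() now routes through it, and new unit tests cover each mismatch case. Hedged usage sketches follow the first two files below.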
40 changes: 28 additions & 12 deletions mediapipe/calculators/tensor/inference_calculator_utils.cc
@@ -169,26 +169,15 @@ absl::Status CopyTensorToTfLiteTensor<char>(const Tensor& input_tensor,
 template <typename T>
 absl::Status CopyTfLiteTensorToTensor(const TfLiteTensor& tflite_tensor,
                                       Tensor& output_tensor) {
+  MP_RETURN_IF_ERROR(TensorDimsAndTypeEqual(output_tensor, tflite_tensor));
   auto output_tensor_view = output_tensor.GetCpuWriteView();
   T* output_tensor_buffer = output_tensor_view.buffer<T>();
   RET_CHECK(output_tensor_buffer) << "Output tensor buffer is null.";
   RET_CHECK_EQ(tflite_tensor.type, tflite::typeToTfLiteType<T>())
       .SetCode(absl::StatusCode::kInvalidArgument)
       << "TfLite tensor type and requested output type do not match.";
-  const Tensor::ElementType output_tensor_type = output_tensor.element_type();
-  RET_CHECK(output_tensor_type == tflite_tensor.type)
-      .SetCode(absl::StatusCode::kInvalidArgument)
-      << "Output and TfLiteTensor types do not match";
   const void* local_tensor_buffer = tflite_tensor.data.raw;
   RET_CHECK(local_tensor_buffer) << "TfLiteTensor tensor buffer is null.";
-  if (!TfLiteIntArrayEqualsArray(tflite_tensor.dims,
-                                 output_tensor.shape().dims.size(),
-                                 output_tensor.shape().dims.data())) {
-    return absl::InvalidArgumentError(
-        absl::StrCat("TfLiteTensor and Tensor shape do not match: ",
-                     GetTfLiteTensorDebugInfo(tflite_tensor), " vs. ",
-                     GetMpTensorDebugInfo(output_tensor)));
-  }
 
   std::memcpy(output_tensor_buffer, local_tensor_buffer, output_tensor.bytes());
   return absl::OkStatus();
@@ -420,4 +409,31 @@ absl::StatusOr<Tensor> CreateTensorWithTfLiteTensorSpecs(
       TfLiteTypeGetName(reference_tflite_tensor.type)));
 }

+absl::Status TensorDimsAndTypeEqual(const Tensor& mp_tensor,
+                                    const TfLiteTensor& tflite_tensor) {
+  const Tensor::ElementType output_tensor_type = mp_tensor.element_type();
+  RET_CHECK(output_tensor_type == tflite_tensor.type)
+      .SetCode(absl::StatusCode::kInvalidArgument)
+      << absl::StrFormat(
+             "MediaPipe and TfLite tensor type do not match: MediaPipe tensor "
+             "type %s vs TfLite tensor type %s.",
+             GetTensorTypeString(output_tensor_type),
+             TfLiteTypeGetName(tflite_tensor.type));
+  if (!TfLiteIntArrayEqualsArray(tflite_tensor.dims,
+                                 mp_tensor.shape().dims.size(),
+                                 mp_tensor.shape().dims.data())) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("TfLiteTensor and Tensor shape do not match: ",
+                     GetTfLiteTensorDebugInfo(tflite_tensor), " vs. ",
+                     GetMpTensorDebugInfo(mp_tensor)));
+  }
+  RET_CHECK_EQ(mp_tensor.bytes(), tflite_tensor.bytes)
+      .SetCode(absl::StatusCode::kInvalidArgument)
+      << absl::StrFormat(
+             "MediaPipe and TfLite tensor bytes do not match: MediaPipe "
+             "tensor bytes %d vs TfLite tensor bytes %d.",
+             mp_tensor.bytes(), tflite_tensor.bytes);
+  return absl::OkStatus();
+}
+
} // namespace mediapipe
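For orientation, here is a minimal caller-side sketch of how the new check is meant to sit in front of a raw copy. GuardedCopy is a hypothetical name, not part of this commit; the sketch assumes only the MediaPipe Tensor and TfLite C APIs used above.

// Hypothetical sketch (not from this commit): validate compatibility first so
// the memcpy below can never read or write past either buffer.
template <typename T>
absl::Status GuardedCopy(const TfLiteTensor& tflite_tensor,
                         mediapipe::Tensor& output_tensor) {
  // Returns absl::StatusCode::kInvalidArgument on any type, shape, or byte
  // size mismatch, before any memory is touched.
  MP_RETURN_IF_ERROR(
      mediapipe::TensorDimsAndTypeEqual(output_tensor, tflite_tensor));
  auto view = output_tensor.GetCpuWriteView();  // keeps the CPU buffer mapped
  std::memcpy(view.buffer<T>(), tflite_tensor.data.raw, output_tensor.bytes());
  return absl::OkStatus();
}

The production code in the diff above additionally RET_CHECKs both raw pointers for null before copying.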
5 changes: 5 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator_utils.h
@@ -17,6 +17,7 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <string>
 
 #include "absl/flags/declare.h"
 #include "absl/status/status.h"
@@ -83,6 +84,10 @@ absl::StatusOr<Tensor> CreateTensorWithTfLiteTensorSpecs(
     const TfLiteTensor& reference_tflite_tensor,
     MemoryManager* memory_manager = nullptr, int alignment = 0);
 
+// Checks that the MediaPipe and TfLite tensor sizes and types match.
+absl::Status TensorDimsAndTypeEqual(const Tensor& mp_tensor,
+                                    const TfLiteTensor& tflite_tensor);
+
} // namespace mediapipe

#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
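A short, illustrative use of the new public declaration (mp_tensor and tflite_tensor are placeholders; the error text comes from the .cc above):

// Illustrative only: reacting to a mismatch reported by the new helper.
absl::Status status =
    mediapipe::TensorDimsAndTypeEqual(mp_tensor, tflite_tensor);
if (!status.ok()) {
  // status.message() names the offending property: type, shape, or bytes.
  ABSL_LOG(ERROR) << "Incompatible tensors: " << status.message();
}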
113 changes: 93 additions & 20 deletions mediapipe/calculators/tensor/inference_calculator_utils_test.cc
@@ -18,7 +18,9 @@
 #include <cstdint>
 #include <cstdlib>
 #include <cstring>
+#include <functional>
 #include <memory>
+#include <numeric>
 #include <string>
 #include <utility>
 #include <vector>
@@ -108,6 +110,53 @@ std::vector<char> TfLiteInputTensorData<char>(const Interpreter& interpreter,
   return std::vector<char>(str.begin(), str.end());
 }

+static absl::StatusOr<int> GetSizeOfType(TfLiteType type) {
+  switch (type) {
+    case kTfLiteFloat16:
+      return sizeof(float) / 2;
+    case kTfLiteFloat32:
+      return sizeof(float);
+    case kTfLiteUInt8:
+      return sizeof(uint8_t);
+    case kTfLiteInt8:
+      return sizeof(int8_t);
+    case kTfLiteInt32:
+      return sizeof(int32_t);
+    case kTfLiteBool:
+      return sizeof(bool);
+    case kTfLiteInt64:
+      return sizeof(int64_t);
+    default:
+      break;
+  }
+  return absl::InvalidArgumentError("Unsupported TfLite type.");
+}
+
+static auto CreateTfLiteTensor(TfLiteType type,
+                               const std::vector<int>& dimensions, float scale,
+                               float zero_point) {
+  auto dealloc = [](TfLiteTensor* tensor) {
+    TfLiteIntArrayFree(tensor->dims);
+    delete (tensor);
+  };
+  std::unique_ptr<TfLiteTensor, decltype(dealloc)> tflite_tensor(
+      new TfLiteTensor, dealloc);
+  tflite_tensor->type = type;
+  tflite_tensor->allocation_type = kTfLiteDynamic;
+  tflite_tensor->quantization.type = kTfLiteNoQuantization;
+  TfLiteIntArray* dims = tflite::ConvertVectorToTfLiteIntArray(dimensions);
+  const int num_elements =
+      std::accumulate(std::begin(dimensions), std::end(dimensions), 1,
+                      std::multiplies<int>());
+  auto size_of_type = GetSizeOfType(type);
+  ABSL_CHECK_OK(size_of_type);
+  tflite_tensor->dims = dims;
+  tflite_tensor->bytes = *size_of_type * num_elements;
+  tflite_tensor->params.scale = scale;
+  tflite_tensor->params.zero_point = zero_point;
+  return tflite_tensor;
+}
+
 class InferenceCalculatorUtilsTest : public ::testing::Test {
  protected:
   void TearDown() override {
@@ -395,7 +444,7 @@ TEST_F(InferenceCalculatorUtilsTest,
       CopyTfLiteTensorIntoCpuOutput(*m.GetOutputTensor(0), tensor);
   EXPECT_FALSE(status.ok());
   EXPECT_THAT(status.message(),
-              HasSubstr("Output and TfLiteTensor types do not match"));
+              HasSubstr("MediaPipe and TfLite tensor type do not match"));
 }

TEST_F(InferenceCalculatorUtilsTest,
@@ -550,6 +599,48 @@ TEST_F(InferenceCalculatorUtilsTest, ShouldNotConfirmTfLiteMemoryAlignment) {
                                        sizeof(int32_t)));
 }

+TEST_F(InferenceCalculatorUtilsTest, TensorDimsAndTypeEqualOk) {
+  std::vector<int> dims = {1, 2, 3, 4};
+  Tensor tensor(ElementType::kInt32, Tensor::Shape(dims));
+  auto tflite_tensor =
+      CreateTfLiteTensor(TfLiteType::kTfLiteInt32, dims, /*scale=*/1.0f,
+                         /*zero_point=*/0.0f);
+  MP_EXPECT_OK(TensorDimsAndTypeEqual(tensor, *tflite_tensor));
+}
+
+TEST_F(InferenceCalculatorUtilsTest,
+       TensorDimsAndTypeEqualDiffersInDimensions) {
+  Tensor tensor(ElementType::kInt32, Tensor::Shape({1, 2, 3, 4}));
+  std::vector<int> dims = {1, 2, 3};
+  auto tflite_tensor =
+      CreateTfLiteTensor(TfLiteType::kTfLiteInt32, dims, /*scale=*/1.0f,
+                         /*zero_point=*/0.0f);
+  EXPECT_THAT(TensorDimsAndTypeEqual(tensor, *tflite_tensor).message(),
+              HasSubstr("TfLiteTensor and Tensor shape do not match"));
+}
+
+TEST_F(InferenceCalculatorUtilsTest, TensorDimsAndTypeEqualDiffersInType) {
+  Tensor tensor(ElementType::kInt32, Tensor::Shape({1, 2, 3, 4}));
+  std::vector<int> dims = {1, 2, 3, 4};
+  auto tflite_tensor =
+      CreateTfLiteTensor(TfLiteType::kTfLiteFloat32, dims, /*scale=*/1.0f,
+                         /*zero_point=*/0.0f);
+  EXPECT_THAT(TensorDimsAndTypeEqual(tensor, *tflite_tensor).message(),
+              HasSubstr("MediaPipe and TfLite tensor type do not match"));
+}
+
+TEST_F(InferenceCalculatorUtilsTest, TensorDimsAndTypeEqualDiffersInSize) {
+  Tensor tensor(ElementType::kInt32, Tensor::Shape({1, 2, 3, 4}));
+  std::vector<int> dims = {1, 2, 3, 4};
+  auto tflite_tensor =
+      CreateTfLiteTensor(TfLiteType::kTfLiteInt32, dims, /*scale=*/1.0f,
+                         /*zero_point=*/0.0f);
+  // Override the size to make it different.
+  tflite_tensor->bytes = 100;
+  EXPECT_THAT(TensorDimsAndTypeEqual(tensor, *tflite_tensor).message(),
+              HasSubstr("MediaPipe and TfLite tensor bytes do not match"));
+}
+
 static std::vector<std::pair<TfLiteType, Tensor::ElementType>>
 GetTensorTypePairs() {
   return {{TfLiteType::kTfLiteFloat16, Tensor::ElementType::kFloat32},
@@ -560,24 +651,6 @@ GetTensorTypePairs() {
           {TfLiteType::kTfLiteBool, Tensor::ElementType::kBool}};
 }

-static auto CreateTfLiteTensor(TfLiteType type, int num_elements, float scale,
-                               float zero_point) {
-  auto dealloc = [](TfLiteTensor* tensor) {
-    TfLiteIntArrayFree(tensor->dims);
-    delete (tensor);
-  };
-  std::unique_ptr<TfLiteTensor, decltype(dealloc)> tflite_tensor(
-      new TfLiteTensor, dealloc);
-  tflite_tensor->type = type;
-  tflite_tensor->allocation_type = kTfLiteDynamic;
-  tflite_tensor->quantization.type = kTfLiteNoQuantization;
-  TfLiteIntArray* dims = tflite::ConvertVectorToTfLiteIntArray({num_elements});
-  tflite_tensor->dims = dims;
-  tflite_tensor->params.scale = scale;
-  tflite_tensor->params.zero_point = zero_point;
-  return tflite_tensor;
-}
-
 class AllocateTensorWithTfLiteTensorSpecsTest
     : public ::testing::TestWithParam<
           std::pair<TfLiteType, Tensor::ElementType>> {};
Expand All @@ -586,7 +659,7 @@ TEST_P(AllocateTensorWithTfLiteTensorSpecsTest,
ShouldAllocateTensorWithTfLiteTensorSpecs) {
const auto& config = GetParam();
const auto tflite_tensor =
CreateTfLiteTensor(config.first, /*num_elements=*/4,
CreateTfLiteTensor(config.first, std::vector<int>({4}),
/*scale=*/2.0f, /*zero_point=*/3.0f);
MP_ASSERT_OK_AND_ASSIGN(Tensor mp_tensor,
CreateTensorWithTfLiteTensorSpecs(
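Assuming MediaPipe's usual Bazel target layout (not shown on this page), the new tests can be run with something like: bazel test //mediapipe/calculators/tensor:inference_calculator_utils_test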
