diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h
new file mode 100644
index 000000000000..db6e35ed51a4
--- /dev/null
+++ b/velox/common/encode/EncoderUtils.h
@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string_view>
+
+#include "velox/common/base/Status.h"
+
+namespace facebook::velox::encoding {
+
+// Padding character used in encoding.
+inline constexpr char kPadding = '=';
+
+// Returns true if 'input' is non-empty and ends with the padding character.
+static inline bool isPadded(std::string_view input) {
+  return !input.empty() && input.back() == kPadding;
+}
+
+// Returns the number of consecutive padding characters at the end of 'input':
+// 0 for an empty input, input.size() if every character is padding.
+static inline size_t numPadding(std::string_view input) {
+  const auto lastNonPadding = input.find_last_not_of(kPadding);
+  return lastNonPadding == std::string_view::npos
+      ? input.size()
+      : input.size() - lastNonPadding - 1;
+}
+
+// Checks that every character of 'charset' at positions [0, index] maps back
+// to its own position through 'reverseIndex'. Declared constexpr so that it
+// can be evaluated at compile time.
+template <typename Charset, typename ReverseIndex>
+constexpr bool checkForwardIndex(
+    uint8_t index,
+    const Charset& charset,
+    const ReverseIndex& reverseIndex) {
+  return (reverseIndex[static_cast<uint8_t>(charset[index])] == index) &&
+      (index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true);
+}
+
+// Returns true if 'targetChar' occurs in 'charset' at position 'index' or
+// later. Iterates with a size_t cursor: advancing the 8-bit 'index' itself
+// would wrap from 255 to 0 and never terminate for charsets of 256 or more
+// entries.
+template <typename Charset>
+constexpr bool findCharacterInCharset(
+    const Charset& charset,
+    uint8_t index,
+    const char targetChar) {
+  for (size_t i = index; i < charset.size(); ++i) {
+    if (charset[i] == targetChar) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Checks that 'reverseIndex' is consistent with 'charset' for every byte value
+// in [0, index]: an entry of 255 must belong to a byte that does not occur in
+// 'charset', and any other entry must be a valid position in 'charset' that
+// holds exactly that byte.
+template <typename Charset, typename ReverseIndex>
+constexpr bool checkReverseIndex(
+    uint8_t index,
+    const Charset& charset,
+    const ReverseIndex& reverseIndex) {
+  // The charset byte is compared as uint8_t: 'char' may be signed, in which
+  // case a byte >= 0x80 would never compare equal to 'index'. Entries that
+  // point past the end of 'charset' are rejected instead of being read.
+  return (reverseIndex[index] == 255
+              ? !findCharacterInCharset(charset, 0, static_cast<char>(index))
+              : (static_cast<size_t>(reverseIndex[index]) < charset.size() &&
+                 static_cast<uint8_t>(charset[reverseIndex[index]]) ==
+                     index)) &&
+      (index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true);
+}
+
+// Looks up 'encodedChar' in 'reverseIndex' and stores the result in
+// 'reverseLookupValue'. Returns a user error if the stored value is 0x40 (64)
+// or greater, i.e. does not fit the 6 bits of a Base64 digit.
+template <typename ReverseIndexType>
+Status base64ReverseLookup(
+    char encodedChar,
+    const ReverseIndexType& reverseIndex,
+    uint8_t& reverseLookupValue) {
+  reverseLookupValue = reverseIndex[static_cast<uint8_t>(encodedChar)];
+  if (reverseLookupValue >= 0x40) {
+    return Status::UserError(fmt::format(
+        "decode() - contains invalid character '{}'", encodedChar));
+  }
+  return Status::OK();
+}
+
+// Computes the number of bytes produced by decoding 'input' and stores it in
+// 'decodedSize'. If 'input' is padded, the number of trailing padding
+// characters is also subtracted from 'inputSize'. 'inputSize' is expected to
+// equal input.size(). 'binaryBlockByteSize' and 'encodedBlockByteSize' are the
+// sizes of one decoded and one encoded block, e.g. 3 and 4 for Base64.
+// Returns a user error if the length cannot belong to a valid encoding.
+static inline Status calculateDecodedSize(
+    std::string_view input,
+    size_t& inputSize,
+    size_t& decodedSize,
+    const int binaryBlockByteSize,
+    const int encodedBlockByteSize) {
+  if (inputSize == 0) {
+    decodedSize = 0;
+    return Status::OK();
+  }
+
+  // Check if the input string is padded.
+  if (isPadded(input)) {
+    // If padded, ensure that the string length is a multiple of the encoded
+    // block size.
+    if (inputSize % encodedBlockByteSize != 0) {
+      return Status::UserError("decode() - invalid input string length.");
+    }
+
+    // The padding is counted on 'input' but deducted from 'inputSize'; reject
+    // a mismatch that would make the unsigned subtraction wrap around.
+    const auto paddingCount = numPadding(input);
+    if (paddingCount > inputSize) {
+      return Status::UserError("decode() - invalid input string length.");
+    }
+
+    decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize;
+    inputSize -= paddingCount;
+
+    // Adjust the needed size by deducting the bytes corresponding to the
+    // padding from the calculated size, rounding up to whole bytes.
+    decodedSize -=
+        ((paddingCount * binaryBlockByteSize) + (encodedBlockByteSize - 1)) /
+        encodedBlockByteSize;
+  } else {
+    // If not padded, calculate extra bytes, if any.
+    const auto extraBytes = inputSize % encodedBlockByteSize;
+    decodedSize = (inputSize / encodedBlockByteSize) * binaryBlockByteSize;
+    // Adjust the needed size for extra bytes, if present.
+    if (extraBytes) {
+      // A single extra character holds less than one byte of data.
+      if (extraBytes == 1) {
+        return Status::UserError("decode() - invalid input string length.");
+      }
+      decodedSize += (extraBytes * binaryBlockByteSize) / encodedBlockByteSize;
+    }
+  }
+  return Status::OK();
+}
+
+// Returns the number of characters needed to encode 'inputSize' bytes. If
+// 'includePadding' is false, the padding characters that would fill up the
+// last encoded block are not counted.
+static inline size_t calculateEncodedSize(
+    size_t inputSize,
+    bool includePadding,
+    const int binaryBlockByteSize,
+    const int encodedBlockByteSize) {
+  if (inputSize == 0) {
+    return 0;
+  }
+
+  const auto binaryBlock = static_cast<size_t>(binaryBlockByteSize);
+  const auto encodedBlock = static_cast<size_t>(encodedBlockByteSize);
+
+  // Calculate the output size assuming that we are including padding.
+  size_t encodedSize =
+      ((inputSize + binaryBlock - 1) / binaryBlock) * encodedBlock;
+
+  if (!includePadding) {
+    // If the padding was not requested, subtract the padding characters of a
+    // partially filled last block. Such a block needs
+    // ceil(remainder * encodedBlock / binaryBlock) characters for its data;
+    // the rest is padding. For Base64 (3/4) that padding count equals
+    // binaryBlock - remainder, which does not hold for other block sizes such
+    // as 5/8.
+    const size_t remainder = inputSize % binaryBlock;
+    if (remainder != 0) {
+      const size_t dataChars =
+          (remainder * encodedBlock + binaryBlock - 1) / binaryBlock;
+      encodedSize -= encodedBlock - dataChars;
+    }
+  }
+  return encodedSize;
+}
+
+} // namespace facebook::velox::encoding
diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt
index 90c9733ecf22..2e1e79ea222e 100644
--- a/velox/common/encode/tests/CMakeLists.txt
+++ b/velox/common/encode/tests/CMakeLists.txt
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-add_executable(velox_common_encode_test Base64Test.cpp)
+add_executable(velox_common_encode_test Base64Test.cpp EncoderUtilsTests.cpp)
 add_test(velox_common_encode_test velox_common_encode_test)
 target_link_libraries(
   velox_common_encode_test
diff --git a/velox/common/encode/tests/EncoderUtilsTests.cpp b/velox/common/encode/tests/EncoderUtilsTests.cpp
new file mode 100644
index 000000000000..c64e245a589e
--- /dev/null
+++ b/velox/common/encode/tests/EncoderUtilsTests.cpp
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <string_view>
+
+#include <gtest/gtest.h>
+
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/common/encode/EncoderUtils.h"
+
+namespace facebook::velox::encoding {
+
+// Fixture providing the Base64 block sizes and alphabet used by the tests.
+class EncoderUtilsTest : public ::testing::Test {
+ protected:
+  // Standard Base64 alphabet.
+  static constexpr std::string_view kCharset =
+      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+  // Builds a 256-entry reverse lookup table for kCharset: each alphabet
+  // character maps to its position and every other byte to 'fillValue'.
+  static std::array<uint8_t, 256> makeReverseIndex(uint8_t fillValue) {
+    std::array<uint8_t, 256> reverseIndex{};
+    reverseIndex.fill(fillValue);
+    for (size_t i = 0; i < kCharset.size(); ++i) {
+      reverseIndex[static_cast<uint8_t>(kCharset[i])] =
+          static_cast<uint8_t>(i);
+    }
+    return reverseIndex;
+  }
+
+  // Calls calculateDecodedSize with 'inputSize' reset to input.size(), so
+  // that no state leaks from one case into the next.
+  Status decodedSizeOf(
+      std::string_view input,
+      size_t& inputSize,
+      size_t& decodedSize) const {
+    inputSize = input.size();
+    return calculateDecodedSize(
+        input,
+        inputSize,
+        decodedSize,
+        binaryBlockByteSize,
+        encodedBlockByteSize);
+  }
+
+  // Calls calculateEncodedSize with the Base64 block sizes.
+  size_t encodedSizeOf(size_t inputSize, bool includePadding) const {
+    return calculateEncodedSize(
+        inputSize, includePadding, binaryBlockByteSize, encodedBlockByteSize);
+  }
+
+  const int binaryBlockByteSize = 3;
+  const int encodedBlockByteSize = 4;
+};
+
+TEST_F(EncoderUtilsTest, isPadded) {
+  EXPECT_TRUE(isPadded("ABC="));
+  EXPECT_FALSE(isPadded("ABC"));
+  EXPECT_FALSE(isPadded(""));
+}
+
+TEST_F(EncoderUtilsTest, numPadding) {
+  EXPECT_EQ(0, numPadding(""));
+  EXPECT_EQ(0, numPadding("ABC"));
+  EXPECT_EQ(1, numPadding("ABC="));
+  EXPECT_EQ(2, numPadding("AB=="));
+  EXPECT_EQ(4, numPadding("===="));
+}
+
+TEST_F(EncoderUtilsTest, CalculateDecodedSizeTest) {
+  size_t inputSize = 0;
+  size_t decodedSize = 0;
+
+  // Unpadded input consisting of full blocks.
+  EXPECT_EQ(decodedSizeOf("abcdabcd", inputSize, decodedSize), Status::OK());
+  EXPECT_EQ(decodedSize, 6);
+  EXPECT_EQ(inputSize, 8);
+
+  // Padded input: the padding is removed from 'inputSize'.
+  EXPECT_EQ(decodedSizeOf("abcdab==", inputSize, decodedSize), Status::OK());
+  EXPECT_EQ(decodedSize, 4);
+  EXPECT_EQ(inputSize, 6);
+
+  // Unpadded input with a partial trailing block.
+  EXPECT_EQ(decodedSizeOf("abcdab", inputSize, decodedSize), Status::OK());
+  EXPECT_EQ(decodedSize, 4);
+  EXPECT_EQ(inputSize, 6);
+
+  // Padded input whose length is not a multiple of the encoded block size.
+  EXPECT_EQ(
+      decodedSizeOf("abcdab=", inputSize, decodedSize),
+      Status::UserError("decode() - invalid input string length."));
+
+  // A single character beyond the last full block is never valid.
+  EXPECT_EQ(
+      decodedSizeOf("abcda", inputSize, decodedSize),
+      Status::UserError("decode() - invalid input string length."));
+}
+
+TEST_F(EncoderUtilsTest, CalculateEncodedSizeTest) {
+  EXPECT_EQ(encodedSizeOf(0, true), 0);
+  EXPECT_EQ(encodedSizeOf(0, false), 0);
+  EXPECT_EQ(encodedSizeOf(3, true), 4);
+  EXPECT_EQ(encodedSizeOf(3, false), 4);
+  EXPECT_EQ(encodedSizeOf(6, true), 8);
+
+  // Partial blocks: padding rounds up to a full encoded block, while the
+  // unpadded size only counts the characters that carry data.
+  EXPECT_EQ(encodedSizeOf(1, true), 4);
+  EXPECT_EQ(encodedSizeOf(1, false), 2);
+  EXPECT_EQ(encodedSizeOf(2, true), 4);
+  EXPECT_EQ(encodedSizeOf(2, false), 3);
+}
+
+TEST_F(EncoderUtilsTest, Base64ReverseLookupTest) {
+  std::array<uint8_t, 256> reverseIndex{};
+  reverseIndex.fill(255);
+  reverseIndex['A'] = 0;
+  reverseIndex['B'] = 1;
+  uint8_t reverseLookupValue = 0;
+
+  EXPECT_EQ(
+      base64ReverseLookup('A', reverseIndex, reverseLookupValue),
+      Status::OK());
+  EXPECT_EQ(reverseLookupValue, 0);
+
+  EXPECT_EQ(
+      base64ReverseLookup('B', reverseIndex, reverseLookupValue),
+      Status::OK());
+  EXPECT_EQ(reverseLookupValue, 1);
+
+  EXPECT_EQ(
+      base64ReverseLookup('Z', reverseIndex, reverseLookupValue),
+      Status::UserError("decode() - contains invalid character 'Z'"));
+}
+
+TEST_F(EncoderUtilsTest, CheckForwardIndexTest) {
+  auto reverseIndex = makeReverseIndex(0);
+  EXPECT_TRUE(checkForwardIndex(63, kCharset, reverseIndex));
+
+  // Break the entry for 'A' (position 0); the check must notice.
+  reverseIndex['A'] = 1;
+  EXPECT_FALSE(checkForwardIndex(63, kCharset, reverseIndex));
+}
+
+TEST_F(EncoderUtilsTest, FindCharacterInCharsetTest) {
+  EXPECT_TRUE(findCharacterInCharset(kCharset, 0, 'A'));
+  EXPECT_TRUE(findCharacterInCharset(kCharset, 0, 'z'));
+
+  // 'A' is at position 0, so a search starting at position 1 misses it.
+  EXPECT_FALSE(findCharacterInCharset(kCharset, 1, 'A'));
+
+  // Characters outside the alphabet are never found.
+  EXPECT_FALSE(findCharacterInCharset(kCharset, 0, '-'));
+  EXPECT_FALSE(findCharacterInCharset(kCharset, 0, '='));
+}
+
+TEST_F(EncoderUtilsTest, CheckReverseIndexTest) {
+  auto reverseIndex = makeReverseIndex(255);
+  EXPECT_TRUE(checkReverseIndex(255, kCharset, reverseIndex));
+
+  // Mark 'A' as absent even though it is part of the alphabet.
+  reverseIndex['A'] = 255;
+  EXPECT_FALSE(checkReverseIndex(255, kCharset, reverseIndex));
+}
+
+} // namespace facebook::velox::encoding