From ef1e44513892127629abf4c608d789b108bc6828 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Wed, 7 Aug 2024 18:44:55 +0530 Subject: [PATCH 1/2] Introduce utility class for encoding --- velox/common/encode/EncoderUtils.h | 167 ++++++++++++++++++ velox/common/encode/tests/CMakeLists.txt | 2 +- .../common/encode/tests/EncoderUtilsTests.cpp | 35 ++++ 3 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 velox/common/encode/EncoderUtils.h create mode 100644 velox/common/encode/tests/EncoderUtilsTests.cpp diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h new file mode 100644 index 000000000000..7c5a8a5b09e5 --- /dev/null +++ b/velox/common/encode/EncoderUtils.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/common/base/Status.h" + +namespace facebook::velox::encoding { + +/// Padding character used in encoding. +const static char kPadding = '='; + +// Checks if the input Base64 string is padded. +static inline bool isPadded(std::string_view input) { + size_t inputSize{input.size()}; + return (inputSize > 0 && input[inputSize - 1] == kPadding); +} + +// Counts the number of padding characters in encoded input. +static inline size_t numPadding(std::string_view input) { + size_t numPadding{0}; + size_t inputSize{input.size()}; + while (inputSize > 0 && input[inputSize - 1] == kPadding) { + numPadding++; + inputSize--; + } + return numPadding; +} + +// Validate the character in charset with ReverseIndex table +template +constexpr bool checkForwardIndex( + uint8_t index, + const Charset& charset, + const ReverseIndex& reverseIndex) { + return (reverseIndex[static_cast(charset[index])] == index) && + (index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true); +} + +// Searches for a character within a charset up to a certain index. +template +constexpr bool findCharacterInCharset( + const Charset& charset, + uint8_t index, + const char targetChar) { + return index < charset.size() && + ((charset[index] == targetChar) || + findCharacterInCharset(charset, index + 1, targetChar)); +} + +// Checks the consistency of a reverse index mapping for a given character set. +template +constexpr bool checkReverseIndex( + uint8_t index, + const Charset& charset, + const ReverseIndex& reverseIndex) { + return (reverseIndex[index] == 255 + ? !findCharacterInCharset(charset, 0, static_cast(index)) + : (charset[reverseIndex[index]] == index)) && + (index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true); +} + +template +uint8_t reverseLookup( + char encodedChar, + const ReverseIndexType& reverseIndex, + Status& status, + uint8_t kBase) { + auto curr = reverseIndex[static_cast(encodedChar)]; + if (curr >= kBase) { + status = + Status::UserError("invalid input string: contains invalid characters."); + return 0; // Return 0 or any other error code indicating failure + } + return curr; +} + +// Returns the actual size of the decoded data. Will also remove the padding +// length from the 'inputSize'. +static Status calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + decodedSize = 0; + return Status::OK(); + } + + // Check if the input string is padded + if (isPadded(input)) { + // If padded, ensure that the string length is a multiple of the encoded + // block size + if (inputSize % encodedBlockByteSize != 0) { + return Status::UserError( + "decode() - invalid input string: " + "string length is not a multiple of 4."); + } + + decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize; + auto paddingCount = numPadding(input); + inputSize -= paddingCount; + + // Adjust the needed size by deducting the bytes corresponding to the + // padding from the calculated size. + decodedSize -= + ((paddingCount * binaryBlockByteSize) + (encodedBlockByteSize - 1)) / + encodedBlockByteSize; + } else { + // If not padded, calculate extra bytes, if any + auto extraBytes = inputSize % encodedBlockByteSize; + decodedSize = (inputSize / encodedBlockByteSize) * binaryBlockByteSize; + // Adjust the needed size for extra bytes, if present + if (extraBytes) { + if (extraBytes == 1) { + return Status::UserError( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4."); + } + decodedSize += (extraBytes * binaryBlockByteSize) / encodedBlockByteSize; + } + } + + return Status::OK(); +} + +// Calculates the encoded size based on input size. +static size_t calculateEncodedSize( + size_t inputSize, + bool includePadding, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + return 0; + } + + // Calculate the output size assuming that we are including padding. + size_t encodedSize = + ((inputSize + binaryBlockByteSize - 1) / binaryBlockByteSize) * + encodedBlockByteSize; + + if (!includePadding) { + // If the padding was not requested, subtract the padding bytes. + size_t remainder = inputSize % binaryBlockByteSize; + if (remainder != 0) { + encodedSize -= (binaryBlockByteSize - remainder); + } + } + + return encodedSize; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 90c9733ecf22..2e1e79ea222e 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_common_encode_test Base64Test.cpp) +add_executable(velox_common_encode_test Base64Test.cpp EncoderUtilsTests.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test diff --git a/velox/common/encode/tests/EncoderUtilsTests.cpp b/velox/common/encode/tests/EncoderUtilsTests.cpp new file mode 100644 index 000000000000..e112f8125349 --- /dev/null +++ b/velox/common/encode/tests/EncoderUtilsTests.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/encode/EncoderUtils.h" + +namespace facebook::velox::encoding { +class EncoderUtilsTest : public ::testing::Test {}; + +TEST_F(EncoderUtilsTest, isPadded) { + EXPECT_TRUE(isPadded("ABC=")); + EXPECT_FALSE(isPadded("ABC")); +} + +TEST_F(EncoderUtilsTest, numPadding) { + EXPECT_EQ(0, numPadding("ABC")); + EXPECT_EQ(1, numPadding("ABC=")); + EXPECT_EQ(2, numPadding("AB==")); +} + +} // namespace facebook::velox::encoding From 6a06e63172de4398f7a2e0a5d801ea703ad03708 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Sat, 5 Oct 2024 10:41:36 +0530 Subject: [PATCH 2/2] Add presto function `to_base32` and `from_base32` --- velox/common/encode/Base32.cpp | 304 ++++++++++++++++++ velox/common/encode/Base32.h | 61 ++++ velox/common/encode/CMakeLists.txt | 2 +- velox/common/encode/EncoderUtils.h | 8 +- velox/docs/functions/presto/binary.rst | 41 +++ velox/functions/prestosql/BinaryFunctions.h | 39 +++ .../BinaryFunctionsRegistration.cpp | 7 + .../prestosql/tests/BinaryFunctionsTest.cpp | 71 ++++ 8 files changed, 526 insertions(+), 7 deletions(-) create mode 100644 velox/common/encode/Base32.cpp create mode 100644 velox/common/encode/Base32.h diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp new file mode 100644 index 000000000000..846fb113b007 --- /dev/null +++ b/velox/common/encode/Base32.cpp @@ -0,0 +1,304 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/encode/Base32.h" + +#include + +namespace facebook::velox::encoding { + +// Constants defining the size in bytes of binary and encoded blocks for Base32 +// encoding. +// Size of a binary block in bytes (5 bytes = 40 bits) +constexpr static int kBinaryBlockByteSize = 5; +// Size of an encoded block in bytes (8 bytes = 40 bits) +constexpr static int kEncodedBlockByteSize = 8; + +constexpr Base32::Charset kBase32Charset = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', + 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', '2', '3', '4', '5', '6', '7'}; + +constexpr Base32::ReverseIndex kBase32ReverseIndexTable = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + +// Verify that for each 32 entries in kBase32Charset, the corresponding entry +// in kBase32ReverseIndexTable is correct. +static_assert( + checkForwardIndex( + sizeof(kBase32Charset) / 2 - 1, + kBase32Charset, + kBase32ReverseIndexTable), + "kBase32Charset has incorrect entries"); + +// Verify that for every entry in kBase32ReverseIndexTable, the corresponding +// entry in kBase32Charset is correct. +static_assert( + checkReverseIndex( + sizeof(kBase32ReverseIndexTable) - 1, + kBase32Charset, + kBase32ReverseIndexTable), + "kBase32ReverseIndexTable has incorrect entries."); + +// static +Status Base32::encode(std::string_view input, std::string& output) { + return encodeImpl(input, true, output); +} + +// static +template +Status +Base32::encodeImpl(const T& input, bool includePadding, std::string& output) { + auto inputSize = input.size(); + if (inputSize == 0) { + output.clear(); + return Status::OK(); + } + + // Calculate the output size and resize the string beforehand + size_t outputSize = calculateEncodedSize( + inputSize, includePadding, kBinaryBlockByteSize, kEncodedBlockByteSize); + output.resize(outputSize); + + // Use a pointer to write into the pre-allocated buffer + auto outputPointer = output.data(); + auto inputIterator = input.begin(); + + // Process 5-byte (40-bit) blocks, split into 8 groups of 5 bits + for (; inputSize > 4; inputSize -= 5) { + uint64_t currentBlock = static_cast(*inputIterator++) << 32; + currentBlock |= static_cast(*inputIterator++) << 24; + currentBlock |= static_cast(*inputIterator++) << 16; + currentBlock |= static_cast(*inputIterator++) << 8; + currentBlock |= static_cast(*inputIterator++); + + *outputPointer++ = kBase32Charset[(currentBlock >> 35) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 30) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 25) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 20) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 15) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 10) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 5) & 0x1f]; + *outputPointer++ = kBase32Charset[currentBlock & 0x1f]; + } + + // Handle remaining bytes (1 to 4 bytes) + if (inputSize > 0) { + uint64_t currentBlock = static_cast(*inputIterator++) << 32; + *outputPointer++ = kBase32Charset[(currentBlock >> 35) & 0x1f]; + + if (inputSize > 3) { + currentBlock |= static_cast(*inputIterator++) << 24; + currentBlock |= static_cast(*inputIterator++) << 16; + currentBlock |= static_cast(*inputIterator++) << 8; + + *outputPointer++ = kBase32Charset[(currentBlock >> 30) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 25) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 20) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 15) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 10) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 5) & 0x1f]; + if (includePadding) { + *outputPointer++ = kPadding; + } + } else if (inputSize > 2) { + currentBlock |= static_cast(*inputIterator++) << 24; + currentBlock |= static_cast(*inputIterator++) << 16; + + *outputPointer++ = kBase32Charset[(currentBlock >> 30) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 25) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 20) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 15) & 0x1f]; + if (includePadding) { + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + } + } else if (inputSize > 1) { + currentBlock |= static_cast(*inputIterator++) << 24; + + *outputPointer++ = kBase32Charset[(currentBlock >> 30) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 25) & 0x1f]; + *outputPointer++ = kBase32Charset[(currentBlock >> 20) & 0x1f]; + if (includePadding) { + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + } + } else { + *outputPointer++ = kBase32Charset[(currentBlock >> 30) & 0x1f]; + if (includePadding) { + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + *outputPointer++ = kPadding; + } + } + } + + return Status::OK(); +} + +// static +uint8_t Base32::base32ReverseLookup(char encodedChar, Status& status) { + return reverseLookup( + encodedChar, kBase32ReverseIndexTable, status, kCharsetSize); +} + +// static +Status Base32::decode(std::string_view input, std::string& output) { + return decodeImpl(input, output); +} + +// static +Status Base32::decodeImpl(std::string_view input, std::string& output) { + size_t inputSize = input.size(); + + // If input is empty, clear output and return OK status. + if (inputSize == 0) { + output.clear(); + return Status::OK(); + } + + // Calculate the decoded size based on the input size. + size_t decodedSize; + auto status = calculateDecodedSize( + input, + inputSize, + decodedSize, + kBinaryBlockByteSize, + kEncodedBlockByteSize); + if (!status.ok()) { + return status; + } + + // Resize the output to accommodate the decoded data. + output.resize(decodedSize); + + const char* inputPtr = input.data(); + char* outputPtr = output.data(); + Status lookupStatus; + + // Process full blocks of 8 characters + size_t fullBlockCount = inputSize / 8; + for (size_t i = 0; i < fullBlockCount; ++i) { + uint64_t inputBlock = 0; + + // Decode 8 characters into a 40-bit block + for (int shift = 35, j = 0; j < 8; ++j, shift -= 5) { + uint64_t value = base32ReverseLookup(inputPtr[j], lookupStatus); + if (!lookupStatus.ok()) { + return lookupStatus; + } + inputBlock |= (value << shift); + } + + // Write the decoded block to the output + outputPtr[0] = static_cast((inputBlock >> 32) & 0xFF); + outputPtr[1] = static_cast((inputBlock >> 24) & 0xFF); + outputPtr[2] = static_cast((inputBlock >> 16) & 0xFF); + outputPtr[3] = static_cast((inputBlock >> 8) & 0xFF); + outputPtr[4] = static_cast(inputBlock & 0xFF); + + inputPtr += 8; + outputPtr += 5; + } + + // Handle remaining characters (2, 4, 5, 7) + size_t remaining = inputSize % 8; + if (remaining >= 2) { + uint64_t inputBlock = 0; + + // Decode the first two characters + inputBlock |= + (static_cast(base32ReverseLookup(inputPtr[0], lookupStatus)) + << 35); + inputBlock |= + (static_cast(base32ReverseLookup(inputPtr[1], lookupStatus)) + << 30); + + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputPtr[0] = static_cast((inputBlock >> 32) & 0xFF); + + if (remaining > 2) { + // Decode the next two characters + inputBlock |= (base32ReverseLookup(inputPtr[2], lookupStatus) << 25); + inputBlock |= (base32ReverseLookup(inputPtr[3], lookupStatus) << 20); + + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputPtr[1] = static_cast((inputBlock >> 24) & 0xFF); + + if (remaining > 4) { + // Decode the next character + inputBlock |= (base32ReverseLookup(inputPtr[4], lookupStatus) << 15); + + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputPtr[2] = static_cast((inputBlock >> 16) & 0xFF); + + if (remaining > 5) { + // Decode the next two characters + inputBlock |= (base32ReverseLookup(inputPtr[5], lookupStatus) << 10); + inputBlock |= (base32ReverseLookup(inputPtr[6], lookupStatus) << 5); + + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputPtr[3] = static_cast((inputBlock >> 8) & 0xFF); + + if (remaining > 7) { + // Decode the last character + inputBlock |= base32ReverseLookup(inputPtr[7], lookupStatus); + + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputPtr[4] = static_cast(inputBlock & 0xFF); + } + } + } + } + } + + // Return status + return Status::OK(); +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base32.h b/velox/common/encode/Base32.h new file mode 100644 index 000000000000..612f25e69801 --- /dev/null +++ b/velox/common/encode/Base32.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/common/base/Status.h" +#include "velox/common/encode/EncoderUtils.h" + +namespace facebook::velox::encoding { + +class Base32 { + public: + static const size_t kCharsetSize = 32; + static const size_t kReverseIndexSize = 256; + + /// Character set used for encoding purposes. + /// Contains specific characters that form the encoding scheme. + using Charset = std::array; + + /// Reverse lookup table for decoding purposes. + /// Maps each possible encoded character to its corresponding numeric value + /// within the encoding base. + using ReverseIndex = std::array; + + /// Encodes the specified number of characters from the 'input' and writes the + /// result to the 'output'. + static Status encode(std::string_view input, std::string& output); + + /// Decodes the specified number of characters from the 'input' and writes the + /// result to the 'output'. + static Status decode(std::string_view input, std::string& output); + + private: + // Performs a reverse lookup in the reverse index to retrieve the original + // index of a character in the base. + static uint8_t base32ReverseLookup(char encodedChar, Status& status); + + // Encodes the specified input using the provided charset. + template + static Status + encodeImpl(const T& input, bool includePadding, std::string& output); + + // Decodes the specified input using the provided reverse lookup table. + static Status decodeImpl(std::string_view input, std::string& output); +}; + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index 501c690c476b..b897399daf8a 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -16,5 +16,5 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_encode Base64.cpp) +velox_add_library(velox_encode Base32.cpp Base64.cpp) velox_link_libraries(velox_encode PUBLIC Folly::folly) diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h index 7c5a8a5b09e5..663ef22fc94d 100644 --- a/velox/common/encode/EncoderUtils.h +++ b/velox/common/encode/EncoderUtils.h @@ -106,9 +106,7 @@ static Status calculateDecodedSize( // If padded, ensure that the string length is a multiple of the encoded // block size if (inputSize % encodedBlockByteSize != 0) { - return Status::UserError( - "decode() - invalid input string: " - "string length is not a multiple of 4."); + return Status::UserError("decode() - invalid input string length."); } decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize; @@ -127,9 +125,7 @@ static Status calculateDecodedSize( // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - return Status::UserError( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4."); + return Status::UserError("decode() - invalid input string length."); } decodedSize += (extraBytes * binaryBlockByteSize) / encodedBlockByteSize; } diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 8b4ddc26832e..2efeada08ae7 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -33,6 +33,29 @@ Binary Functions Decodes ``string`` data from the base64 encoded representation using the `URL safe alphabet `_ into a varbinary. +.. function:: from_base64(string) -> varbinary + + Decodes a Base64-encoded ``string`` back into its original binary form. + This function can handle both padded and non-padded Base64 encoded strings. + Partially padded Base64 strings will result in a "UserError" status being returned. + + Examples + -------- + Query with padded Base64 string: + :: + SELECT from_base64('SGVsbG8gV29ybGQ='); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + + Query with non-padded Base64 string: + :: + SELECT from_base64('SGVsbG8gV29ybGQ'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + + Query with partially padded Base64 string: + :: + SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- Error: Base64::decode() - invalid input string: length is not a multiple of 4. + + In the examples above, both fully padded and non-padded Base64 strings ('SGVsbG8gV29ybGQ=' and 'SGVsbG8gV29ybGQ') decode to the binary representation of the text 'Hello World'. + The partially padded Base64 string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' will result in a "UserError" status indicating the Base64 string is invalid. + .. function:: from_big_endian_32(varbinary) -> integer Decodes ``integer`` value from a 32-bit 2’s complement big endian ``binary``. @@ -123,6 +146,24 @@ Binary Functions Encodes ``binary`` into a base64 string representation. +.. function:: to_base32(varbinary) -> string + + Encodes a binary ``varbinary`` value into its Base32 string representation. + This function generates padded Base32 strings by default. + + Examples + -------- + Query to encode a binary value to a padded Base32 string: + :: + SELECT to_base32(ARRAY[72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100]); -- 'JBSWY3DPEBLW64TMMQ======' + + Query to encode a binary value with fewer bytes: + :: + SELECT to_base32(ARRAY[104, 101, 108, 108, 111]); -- 'NBSWY3DP' + + In the above examples, the binary array `[72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100]` is encoded to the padded Base32 string 'JBSWY3DPEBLW64TMMQ======'. + The binary array `[104, 101, 108, 108, 111]` is encoded to 'NBSWY3DP'. + .. function:: to_base64url(binary) -> varchar Encodes ``binary`` into a base64 string representation using the `URL safe alphabet `_. diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index c9495daff1bb..492f3b217e27 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -21,6 +21,7 @@ #include "folly/ssl/OpenSSLHash.h" #include "velox/common/base/BitUtil.h" +#include "velox/common/encode/Base32.h" #include "velox/common/encode/Base64.h" #include "velox/external/md5/md5.h" #include "velox/functions/Udf.h" @@ -328,6 +329,44 @@ struct ToBase64UrlFunction { } }; +template +struct ToBase32Function { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { + std::string_view inputView(input.data(), input.size()); + std::string output; + auto status = encoding::Base32::encode(inputView, output); + if (!status.ok()) { + return status; + } + result.resize(output.size()); + std::memcpy(result.data(), output.data(), output.size()); + return Status::OK(); + } +}; + +template +struct FromBase32Function { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + // T can be either arg_type or arg_type. These are the + // same, but hard-coding one of them might be confusing. + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { + std::string_view inputView(input.data(), input.size()); + std::string output; + auto status = encoding::Base32::decode(inputView, output); + if (!status.ok()) { + return status; + } + result.resize(output.size()); + std::memcpy(result.data(), output.data(), output.size()); + return Status::OK(); + } +}; + template struct FromBigEndian32 { VELOX_DEFINE_FUNCTION_TYPES(T); diff --git a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp index 6f098ebadc51..6ac4ff1bce75 100644 --- a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp @@ -56,6 +56,13 @@ void registerSimpleFunctions(const std::string& prefix) { registerFunction( {prefix + "from_base64url"}); + registerFunction( + {prefix + "to_base32"}); + registerFunction( + {prefix + "from_base32"}); + registerFunction( + {prefix + "from_base32"}); + registerFunction( {prefix + "from_big_endian_32"}); registerFunction( diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index a1582e9f5eb0..22d4d5e8df59 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -474,6 +474,77 @@ TEST_F(BinaryFunctionsTest, fromBase64Url) { EXPECT_THROW(fromBase64Url("YQ=/"), VeloxUserError); } +TEST_F(BinaryFunctionsTest, toBase32) { + const auto toBase32 = [&](std::optional value) { + return evaluateOnce("to_base32(cast(c0 as varbinary))", value); + }; + + EXPECT_EQ(std::nullopt, toBase32(std::nullopt)); + EXPECT_EQ("", toBase32("")); + EXPECT_EQ("ME======", toBase32("a")); + EXPECT_EQ("MFRGG===", toBase32("abc")); + EXPECT_EQ("NZXQ====", toBase32("no")); + EXPECT_EQ("O5SQ====", toBase32("we")); + EXPECT_EQ("MRRDE===", toBase32("db2")); + EXPECT_EQ("MNQWWZI=", toBase32("cake")); + EXPECT_EQ("NNSWK3Q=", toBase32("keen")); + EXPECT_EQ("GEZDGNA=", toBase32("1234")); + EXPECT_EQ("NBSWY3DPEB3W64TMMQ======", toBase32("hello world")); + EXPECT_EQ( + "JBSWY3DPEBLW64TMMQQGM4TPNUQFMZLMN54CC===", + toBase32("Hello World from Velox!")); +} + +TEST_F(BinaryFunctionsTest, fromBase32) { + const auto fromBase32 = [&](std::optional value) { + // from_base32 allows VARCHAR and VARBINARY inputs. + auto result = + evaluateOnce("from_base32(c0)", VARCHAR(), value); + auto otherResult = + evaluateOnce("from_base32(c0)", VARBINARY(), value); + + VELOX_CHECK_EQ(result.has_value(), otherResult.has_value()); + + if (!result.has_value()) { + return result; + } + + VELOX_CHECK_EQ(result.value(), otherResult.value()); + return result; + }; + + EXPECT_EQ(std::nullopt, fromBase32(std::nullopt)); + EXPECT_EQ("", fromBase32("")); + EXPECT_EQ("a", fromBase32("ME======")); + EXPECT_EQ("ab", fromBase32("MFRA====")); + EXPECT_EQ("abc", fromBase32("MFRGG===")); + EXPECT_EQ("db2", fromBase32("MRRDE===")); + EXPECT_EQ("abcd", fromBase32("MFRGGZA=")); + EXPECT_EQ("hello world", fromBase32("NBSWY3DPEB3W64TMMQ======")); + EXPECT_EQ( + "Hello World from Velox!", + fromBase32("JBSWY3DPEBLW64TMMQQGM4TPNUQFMZLMN54CC===")); + + // Try encoded strings without padding + EXPECT_EQ("a", fromBase32("ME")); + EXPECT_EQ("ab", fromBase32("MFRA")); + EXPECT_EQ("abc", fromBase32("MFRGG")); + EXPECT_EQ("db2", fromBase32("MRRDE")); + EXPECT_EQ("abcd", fromBase32("MFRGGZA")); + EXPECT_EQ("1234", fromBase32("GEZDGNA")); + EXPECT_EQ("abcde", fromBase32("MFRGGZDF")); + EXPECT_EQ("abcdef", fromBase32("MFRGGZDFMY")); + + VELOX_ASSERT_USER_THROW( + fromBase32("1="), "decode() - invalid input string length."); + VELOX_ASSERT_USER_THROW( + fromBase32("M1======"), + "invalid input string: contains invalid characters."); + VELOX_ASSERT_USER_THROW( + fromBase32("J$======"), + "invalid input string: contains invalid characters."); +} + TEST_F(BinaryFunctionsTest, fromBigEndian32) { const auto fromBigEndian32 = [&](const std::optional& arg) { return evaluateOnce("from_big_endian_32(c0)", VARBINARY(), arg);