From 833a103d4a2b3f06e9ac5b40b105f908d6e0400b Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Thu, 1 Feb 2024 11:42:44 +0530 Subject: [PATCH] Add from_base32 presto function --- velox/common/encode/Base32.cpp | 191 ++++++++++++++++++ velox/common/encode/Base32.h | 52 +++++ velox/common/encode/CMakeLists.txt | 2 +- velox/common/encode/tests/Base32Test.cpp | 56 +++++ velox/common/encode/tests/CMakeLists.txt | 3 +- velox/docs/functions/presto/binary.rst | 22 ++ velox/functions/prestosql/BinaryFunctions.h | 17 ++ .../BinaryFunctionsRegistration.cpp | 2 + .../prestosql/tests/BinaryFunctionsTest.cpp | 36 ++++ 9 files changed, 379 insertions(+), 2 deletions(-) create mode 100644 velox/common/encode/Base32.cpp create mode 100644 velox/common/encode/Base32.h create mode 100644 velox/common/encode/tests/Base32Test.cpp diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp new file mode 100644 index 0000000000000..5778afabe09ce --- /dev/null +++ b/velox/common/encode/Base32.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/encode/Base32.h" + +#include + +namespace facebook::velox::encoding { + +// Encoding base to be used. +constexpr static int kBase = 32; + +// Constants defining the size in bytes of binary and encoded blocks for Base32 +// encoding. +// Size of a binary block in bytes (5 bytes = 40 bits) +constexpr static int kBinaryBlockByteSize = 5; +// Size of an encoded block in bytes (8 bytes = 40 bits) +constexpr static int kEncodedBlockByteSize = 8; + +constexpr Charset kBase32Charset = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', + 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', + 'Y', 'Z', '2', '3', '4', '5', '6', '7'}; + +constexpr ReverseIndex kBase32ReverseIndexTable = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + +/// Verify that for each 32 entries in kBase32Charset, the corresponding entry +/// in kBase32ReverseIndexTable is correct. +static_assert( + checkForwardIndex( + sizeof(kBase32Charset) / 2 - 1, + kBase32Charset, + kBase32ReverseIndexTable), + "kBase32Charset has incorrect entries"); + +/// Verify that for every entry in kBase32ReverseIndexTable, the corresponding +/// entry in kBase32Charset is correct. +static_assert( + checkReverseIndex( + sizeof(kBase32ReverseIndexTable) - 1, + kBase32Charset, + kBase, + kBase32ReverseIndexTable), + "kBase32ReverseIndexTable has incorrect entries."); + +size_t Base32::calculateDecodedSize(const char* data, size_t& size) { + if (size == 0) { + return 0; + } + + // Check if the input data is padded + if (isPadded(data, size)) { + /// If padded, ensure that the string length is a multiple of the encoded + /// block size. + if (size % kEncodedBlockByteSize != 0) { + VELOX_USER_FAIL( + "Base32::decode() - invalid input string: " + "string length is not a multiple of 8."); + } + + auto needed = (size * kBinaryBlockByteSize) / kEncodedBlockByteSize; + auto padding = numPadding(data, size); + size -= padding; + + // Adjust the needed size by deducting the bytes corresponding to the + // padding from the calculated size. + return needed - + ((padding * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / + kEncodedBlockByteSize; + } else { + // If not padded, calculate extra bytes, if any. + auto extra = size % kEncodedBlockByteSize; + auto needed = (size / kEncodedBlockByteSize) * kBinaryBlockByteSize; + + // Adjust the needed size for extra bytes, if present. + if (extra) { + if ((extra == 6) || (extra == 3) || (extra == 1)) { + VELOX_USER_FAIL( + "Base32::decode() - invalid input string: " + "string length cannot be 6, 3 or 1 more than a multiple of 8."); + } + needed += (extra * kBinaryBlockByteSize) / kEncodedBlockByteSize; + } + + return needed; + } +} + +size_t +Base32::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { + return decodeImpl(src, src_len, dst, dst_len, kBase32ReverseIndexTable); +} + +size_t Base32::decodeImpl( + const char* src, + size_t src_len, + char* dst, + size_t dst_len, + const ReverseIndex& reverse_lookup) { + if (!src_len) { + return 0; + } + + auto needed = calculateDecodedSize(src, src_len); + if (dst_len < needed) { + VELOX_USER_FAIL( + "Base32::decode() - invalid output string: " + "output string is too small."); + } + + // Handle full groups of 8 characters. + for (; src_len > 8; src_len -= 8, src += 8, dst += 5) { + /// Each character of the 8 bytes encode 5 bits of the original, grab each + /// with the appropriate shifts to rebuild the original and then split that + /// back into the original 8 bit bytes. + uint64_t last = + (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | + (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30) | + (baseReverseLookup(kBase, src[2], reverse_lookup) << 25) | + (baseReverseLookup(kBase, src[3], reverse_lookup) << 20) | + (baseReverseLookup(kBase, src[4], reverse_lookup) << 15) | + (baseReverseLookup(kBase, src[5], reverse_lookup) << 10) | + (baseReverseLookup(kBase, src[6], reverse_lookup) << 5) | + baseReverseLookup(kBase, src[7], reverse_lookup); + dst[0] = (last >> 32) & 0xff; + dst[1] = (last >> 24) & 0xff; + dst[2] = (last >> 16) & 0xff; + dst[3] = (last >> 8) & 0xff; + dst[4] = last & 0xff; + } + + /// Handle the last 2, 4, 5, 7 or 8 characters. This is similar to the above, + /// but the last characters may or may not exist. + DCHECK(src_len >= 2); + uint64_t last = + (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | + (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30); + dst[0] = (last >> 32) & 0xff; + if (src_len > 2) { + last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 25; + last |= baseReverseLookup(kBase, src[3], reverse_lookup) << 20; + dst[1] = (last >> 24) & 0xff; + if (src_len > 4) { + last |= baseReverseLookup(kBase, src[4], reverse_lookup) << 15; + dst[2] = (last >> 16) & 0xff; + if (src_len > 5) { + last |= baseReverseLookup(kBase, src[5], reverse_lookup) << 10; + last |= baseReverseLookup(kBase, src[6], reverse_lookup) << 5; + dst[3] = (last >> 8) & 0xff; + if (src_len > 7) { + last |= baseReverseLookup(kBase, src[7], reverse_lookup); + dst[4] = last & 0xff; + } + } + } + } + + return needed; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base32.h b/velox/common/encode/Base32.h new file mode 100644 index 0000000000000..a3b1688679b87 --- /dev/null +++ b/velox/common/encode/Base32.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/encode/EncoderUtils.h" + +namespace facebook::velox::encoding { + +class Base32 { + public: + /// Returns decoded size for the specified input. Adjusts the 'size' to + /// subtract the length of the padding, if exists. + static size_t calculateDecodedSize(const char* data, size_t& size); + + /// Decodes the specified number of characters from the 'src' and writes the + /// result to the 'dst'. The destination must have enough space, e.g. as + /// returned by the calculateDecodedSize(). + static size_t + decode(const char* src, size_t src_len, char* dst, size_t dst_len); + + private: + /// Decodes the specified number of base 32 encoded characters from the 'src' + /// and writes to 'dst' + static size_t decodeImpl( + const char* src, + size_t src_len, + char* dst, + size_t dst_len, + const ReverseIndex& table); +}; + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index bc27527e14ace..9709df2486807 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -16,5 +16,5 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -add_library(velox_encode Base64.cpp) +add_library(velox_encode Base32.cpp Base64.cpp) target_link_libraries(velox_encode PUBLIC Folly::folly) diff --git a/velox/common/encode/tests/Base32Test.cpp b/velox/common/encode/tests/Base32Test.cpp new file mode 100644 index 0000000000000..3792d25a51230 --- /dev/null +++ b/velox/common/encode/tests/Base32Test.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/encode/Base32.h" +#include +#include "velox/common/base/tests/GTestUtils.h" + +namespace facebook::velox::encoding { + +class Base32Test : public ::testing::Test {}; + +TEST_F(Base32Test, calculateDecodedSizeProperSize) { + struct TestCase { + std::string encoded; + size_t initial_size; + int expected_decoded; + size_t expected_size; + }; + + std::vector test_cases = { + {"ME======", 8, 1, 2}, + {"ME", 2, 1, 2}, + {"MFRA====", 8, 2, 4}, + {"MFRGG===", 8, 3, 5}, + {"NBSWY3DPEB3W64TMMQ======", 24, 11, 18}, + {"NBSWY3DPEB3W64TMMQ", 18, 11, 18}}; + + for (const auto& test : test_cases) { + size_t encoded_size = test.initial_size; + EXPECT_EQ( + test.expected_decoded, + Base32::calculateDecodedSize(test.encoded.c_str(), encoded_size)); + EXPECT_EQ(test.expected_size, encoded_size); + } +} + +TEST_F(Base32Test, errorWhenDecodedStringPartiallyPadded) { + size_t encoded_size = 9; + EXPECT_THROW( + Base32::calculateDecodedSize("MFRA====", encoded_size), VeloxUserError); +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 548ce580e4691..5c63d172999f7 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_common_encode_test Base64Test.cpp EncoderUtilsTests.cpp) +add_executable(velox_common_encode_test Base32Test.cpp Base64Test.cpp + EncoderUtilsTests.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 8b4ddc26832ea..4492237e53fad 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -41,6 +41,28 @@ Binary Functions Decodes ``bigint`` value from a 64-bit 2’s complement big endian ``binary``. +.. function:: from_base32(string) -> varbinary + + Decodes a Base32-encoded ``string`` back into its original binary form. + This function can handle both padded and non-padded Base32 encoded strings. Partially padded Base32 strings will result in an error. + + Examples + -------- + Query with padded Base32 string: + :: + SELECT from_base32('JBSWY3DPEBLW64TMMQ======'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + + Query with non-padded Base32 string: + :: + SELECT from_base32('JBSWY3DPEBLW64TMMQ'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] + + Query with partially padded Base32 string: + :: + SELECT from_base32('JBSWY3DPEBLW64TM=='); -- Error: Base32::decode() - invalid input string: length is not a multiple of 8. + + In the examples above, both fully padded and non-padded Base32 strings ('JBSWY3DPEBLW64TMMQ======' and 'JBSWY3DPEBLW64TMMQ') decode to the binary representation of the text 'Hello World'. + The partially padded Base32 string 'JBSWY3DPEBLW64TM==' will lead to a decoding error. + .. function:: from_hex(string) -> varbinary Decodes binary data from the hex encoded ``string``. diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index d733553ce4a92..cdeb527aa75ee 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -21,6 +21,7 @@ #include "folly/ssl/OpenSSLHash.h" #include "velox/common/base/BitUtil.h" +#include "velox/common/encode/Base32.h" #include "velox/common/encode/Base64.h" #include "velox/external/md5/md5.h" #include "velox/functions/Udf.h" @@ -324,6 +325,22 @@ struct ToBase64UrlFunction { } }; +template +struct FromBase32Function { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + // T can be either arg_type or arg_type. These are the + // same, but hard-coding one of them might be confusing. + template + FOLLY_ALWAYS_INLINE void call(out_type& result, const T& input) { + auto inputSize = input.size(); + result.resize( + encoding::Base32::calculateDecodedSize(input.data(), inputSize)); + encoding::Base32::decode( + input.data(), inputSize, result.data(), result.size()); + } +}; + template struct FromBigEndian32 { VELOX_DEFINE_FUNCTION_TYPES(T); diff --git a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp index 6f098ebadc51a..77a8a0e31fbfb 100644 --- a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp @@ -55,6 +55,8 @@ void registerSimpleFunctions(const std::string& prefix) { {prefix + "to_base64url"}); registerFunction( {prefix + "from_base64url"}); + registerFunction( + {prefix + "from_base32"}); registerFunction( {prefix + "from_big_endian_32"}); diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index 72ef47e22b105..861e2271f80c8 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -477,6 +477,42 @@ TEST_F(BinaryFunctionsTest, fromBase64Url) { EXPECT_THROW(fromBase64Url("YQ=/"), VeloxUserError); } +TEST_F(BinaryFunctionsTest, fromBase32) { + const auto fromBase32 = [&](std::optional value) { + return evaluateOnce("from_base32(c0)", value); + }; + + EXPECT_EQ(std::nullopt, fromBase32(std::nullopt)); + EXPECT_EQ("", fromBase32("")); + EXPECT_EQ("a", fromBase32("ME======")); + EXPECT_EQ("ab", fromBase32("MFRA====")); + EXPECT_EQ("abc", fromBase32("MFRGG===")); + EXPECT_EQ("db2", fromBase32("MRRDE===")); + EXPECT_EQ("abcd", fromBase32("MFRGGZA=")); + EXPECT_EQ("hello world", fromBase32("NBSWY3DPEB3W64TMMQ======")); + EXPECT_EQ( + "Hello World from Velox!", + fromBase32("JBSWY3DPEBLW64TMMQQGM4TPNUQFMZLMN54CC===")); + + // Try encoded strings without padding + EXPECT_EQ("a", fromBase32("ME")); + EXPECT_EQ("ab", fromBase32("MFRA")); + EXPECT_EQ("abc", fromBase32("MFRGG")); + EXPECT_EQ("db2", fromBase32("MRRDE")); + EXPECT_EQ("abcd", fromBase32("MFRGGZA")); + EXPECT_EQ("1234", fromBase32("GEZDGNA")); + EXPECT_EQ("abcde", fromBase32("MFRGGZDF")); + EXPECT_EQ("abcdef", fromBase32("MFRGGZDFMY")); + + // Check with invaild encoded strings + EXPECT_THROW(fromBase32("1="), VeloxUserError); + EXPECT_THROW(fromBase32("M1======"), VeloxUserError); + + VELOX_ASSERT_THROW( + fromBase32("J1======"), + "decode() - invalid input string: invalid characters"); +} + TEST_F(BinaryFunctionsTest, fromBigEndian32) { const auto fromBigEndian32 = [&](const std::optional& arg) { return evaluateOnce("from_big_endian_32(c0)", VARBINARY(), arg);