From 3834919d241fe976b24f17fcd62b96ba40f8bfa7 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Tue, 21 Nov 2023 12:05:13 +0530 Subject: [PATCH] Refactor base64 and additional testcase for without padding --- velox/common/encode/Base.cpp | 50 +++++ velox/common/encode/Base.h | 92 ++++++++ velox/common/encode/Base64.cpp | 198 +++++++----------- velox/common/encode/Base64.h | 55 ++--- velox/common/encode/CMakeLists.txt | 4 +- velox/functions/prestosql/BinaryFunctions.h | 13 +- .../prestosql/tests/BinaryFunctionsTest.cpp | 5 + 7 files changed, 250 insertions(+), 167 deletions(-) create mode 100644 velox/common/encode/Base.cpp create mode 100644 velox/common/encode/Base.h diff --git a/velox/common/encode/Base.cpp b/velox/common/encode/Base.cpp new file mode 100644 index 0000000000000..e14d164bce542 --- /dev/null +++ b/velox/common/encode/Base.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/encode/Base.h" + +#include + +namespace facebook::velox::encoding { + +bool Base::isPadded(const char* data, size_t len) { + return (len > 0 && data[len - 1] == kBasePad) ? true : false; +} + +size_t Base::countPadding(const char* src, size_t len) { + size_t padding_count = 0; + while (len > 0 && src[len - 1] == kBasePad) { + padding_count++; + len--; + } + + return padding_count; +} + +uint8_t Base::baseReverseLookup( + const int base, + char p, + const Base::ReverseIndex& reverse_lookup) { + auto curr = reverse_lookup[(uint8_t)p]; + // Value of encoded character shall be less than base. + if (curr >= base) { + throw BaseException( + "Base::decode() - invalid input string: invalid characters"); + } + + return curr; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base.h b/velox/common/encode/Base.h new file mode 100644 index 0000000000000..fdd710b62425b --- /dev/null +++ b/velox/common/encode/Base.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +namespace facebook::velox::encoding { + +class BaseException : public std::exception { + public: + explicit BaseException(const char* msg) : msg_(msg) {} + const char* what() const noexcept override { + return msg_; + } + + protected: + const char* msg_; +}; + +/// Base class for all binary encoding scheme for reversibly translating between +/// byte sequences and printable ASCII strings specified by RFC 4648. +class Base { + public: + using Charset = std::array; + using ReverseIndex = std::array; + + // Checks is there padding in encoded data + static bool isPadded(const char* src, size_t len); + + // Counts the number of padding characters in encoded data. + static size_t countPadding(const char* src, size_t len); + + // Gets value corresponding to an encoded character + static uint8_t + baseReverseLookup(const int base, char p, const ReverseIndex& table); + + // Padding character used in encoding + constexpr static char kBasePad = '='; +}; + +// Validate the character in charset with ReverseIndex table +constexpr bool checkForwardIndex( + uint8_t idx, + const Base::Charset& charset, + const Base::ReverseIndex& table) { + return (table[static_cast(charset[idx])] == idx) && + (idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true); +} + +/// Similar to strchr(), but for null-terminated const strings. +/// Another difference is that we do not consider "\0" to be present in the +/// string. +/// Returns true if "str" contains the character c. +constexpr bool constCharsetContains( + const Base::Charset& charset, + int base, + uint8_t idx, + const char c) { + return idx < base && + ((charset[idx] == c) || constCharsetContains(charset, base, idx + 1, c)); +} + +// Validate the value in ReverseIndex table with charset. +constexpr bool checkReverseIndex( + uint8_t idx, + const Base::Charset& charset, + int base, + const Base::ReverseIndex& table) { + return (table[idx] == 255 + ? !constCharsetContains(charset, base, 0, static_cast(idx)) + : (charset[table[idx]] == idx)) && + (idx > 0 ? checkReverseIndex(idx - 1, charset, base, table) : true); +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 85fd843b86a83..09e4f2cf0821b 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -74,67 +74,43 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -constexpr bool checkForwardIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true); -} -// Verify that for every entry in kBase64Charset, the corresponding entry -// in kBase64ReverseIndexTable is correct. +/// Verify that for every entry in kBase64Charset, the corresponding entry +/// in kBase64ReverseIndexTable is correct. static_assert( checkForwardIndex( sizeof(kBase64Charset) - 1, kBase64Charset, kBase64ReverseIndexTable), "kBase64Charset has incorrect entries"); -// Verify that for every entry in kBase64UrlCharset, the corresponding entry -// in kBase64UrlReverseIndexTable is correct. +/// Verify that for every entry in kBase64UrlCharset, the corresponding entry +/// in kBase64UrlReverseIndexTable is correct. static_assert( checkForwardIndex( sizeof(kBase64UrlCharset) - 1, kBase64UrlCharset, kBase64UrlReverseIndexTable), "kBase64UrlCharset has incorrect entries"); -// Similar to strchr(), but for null-terminated const strings. -// Another difference is that we do not consider "\0" to be present in the -// string. -// Returns true if "str" contains the character c. -constexpr bool constCharsetContains( - const Base64::Charset& charset, - uint8_t idx, - const char c) { - return idx < charset.size() && - ((charset[idx] == c) || constCharsetContains(charset, idx + 1, c)); -} -constexpr bool checkReverseIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[idx] == 255 - ? !constCharsetContains(charset, 0, static_cast(idx)) - : (charset[table[idx]] == idx)) && - (idx > 0 ? checkReverseIndex(idx - 1, charset, table) : true); -} -// Verify that for every entry in kBase64ReverseIndexTable, the corresponding -// entry in kBase64Charset is correct. + +/// Verify that for every entry in kBase64ReverseIndexTable, the corresponding +/// entry in kBase64Charset is correct. static_assert( checkReverseIndex( sizeof(kBase64ReverseIndexTable) - 1, kBase64Charset, + Base64::kBase, kBase64ReverseIndexTable), "kBase64ReverseIndexTable has incorrect entries."); -// Verify that for every entry in kBase64ReverseIndexTable, the corresponding -// entry in kBase64Charset is correct. -// We can't run this check as the URL version has two duplicate entries so that -// the url decoder can handle url encodings and default encodings -// static_assert( -// checkReverseIndex( -// sizeof(kBase64UrlReverseIndexTable) - 1, -// kBase64UrlCharset, -// kBase64UrlReverseIndexTable), -// "kBase64UrlReverseIndexTable has incorrect entries."); + +/// Verify that for every entry in kBase64ReverseIndexTable, the corresponding +/// entry in kBase64Charset is correct. +/// We can't run this check as the URL version has two duplicate entries so that +/// the url decoder can handle url encodings and default encodings +/// static_assert( +/// checkReverseIndex( +/// sizeof(kBase64UrlReverseIndexTable) - 1, +/// kBase64UrlCharset, +/// kBase64UrlReverseIndexTable), +/// "kBase64UrlReverseIndexTable has incorrect entries."); template /* static */ std::string @@ -187,8 +163,8 @@ template auto wp = out; auto it = data.begin(); - // For each group of 3 bytes (24 bits) in the input, split that into - // 4 groups of 6 bits and encode that using the supplied charset lookup + /// For each group of 3 bytes (24 bits) in the input, split that into + /// 4 groups of 6 bits and encode that using the supplied charset lookup for (; len > 2; len -= 3) { uint32_t curr = uint8_t(*it++) << 16; curr |= uint8_t(*it++) << 8; @@ -201,9 +177,9 @@ template } if (len > 0) { - // We have either 1 or 2 input bytes left. Encode this similar to the - // above (assuming 0 for all other bytes). Optionally append the '=' - // character if it is requested. + /// We have either 1 or 2 input bytes left. Encode this similar to the + /// above (assuming 0 for all other bytes). Optionally append the '=' + /// character if it is requested. uint32_t curr = uint8_t(*it++) << 16; *wp++ = charset[(curr >> 18) & 0x3f]; if (len > 1) { @@ -211,13 +187,13 @@ template *wp++ = charset[(curr >> 12) & 0x3f]; *wp++ = charset[(curr >> 6) & 0x3f]; if (include_pad) { - *wp = kBase64Pad; + *wp = kBasePad; } } else { *wp++ = charset[(curr >> 12) & 0x3f]; if (include_pad) { - *wp++ = kBase64Pad; - *wp = kBase64Pad; + *wp++ = kBasePad; + *wp = kBasePad; } } } @@ -246,8 +222,8 @@ class IOBufWrapper { explicit Iterator(const folly::IOBuf* data) : cs_(data) {} Iterator& operator++(int32_t) { - // This is a noop since reading from the Cursor has already moved the - // position + /// This is a noop since reading from the Cursor has already moved the + /// position return *this; } @@ -304,71 +280,54 @@ void Base64::decode( output.resize(out_len); } -// static void Base64::decode(const char* data, size_t size, char* output) { size_t out_len = size / 4 * 3; Base64::decode(data, size, output, out_len); } -uint8_t Base64::Base64ReverseLookup( - char p, - const Base64::ReverseIndex& reverse_lookup) { - auto curr = reverse_lookup[(uint8_t)p]; - if (curr >= 0x40) { - throw Base64Exception( - "Base64::decode() - invalid input string: invalid characters"); - } - - return curr; -} - size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { - return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true); + return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); } -// static -size_t -Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) { +size_t Base64::calculateDecodedSize(const char* data, size_t& size) { if (size == 0) { return 0; } - auto needed = (size / 4) * 3; - if (withPadding) { - // If the pad characters are included then the source string must be a - // multiple of 4 and we can query the end of the string to see how much - // padding exists. - if (size % 4 != 0) { - throw Base64Exception( - "Base64::decode() - invalid input string: " - "string length is not multiple of 4."); + // If padding doesn't exist, add count for the extra bytes + if (!isPadded(data, size)) { + /// If padding doesn't exist we need to calculate it from the size - if the + /// size % 4 is 0 then we have an even multiple 3 byte chunks in the result + /// if it is 2 then we need 1 more byte in the output. If it is 3 then we + /// need 2 more bytes in the output. It should never be 1. + auto extra = size % kEncodedBlockSize; + auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize; + if (extra) { + if (extra == 1) { + throw BaseException( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4."); + } + needed += (extra * kBinaryBlockSize) / kEncodedBlockSize; } - - auto padding = countPadding(data, size); - size -= padding; - return needed - padding; + return needed; } - // If padding doesn't exist we need to calculate it from the size - if the - // size % 4 is 0 then we have an even multiple 3 byte chunks in the result - // if it is 2 then we need 1 more byte in the output. If it is 3 then we - // need 2 more bytes in the output. It should never be 1. - auto extra = size % 4; - if (extra) { - if (extra == 1) { - throw Base64Exception( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4."); - } - return needed + extra - 1; + /// If the pad characters are included then the source string must be a + /// multiple of encoded block size and we can query the end of the string + /// to see how much padding exists. + if (size % kEncodedBlockSize != 0) { + throw BaseException( + "Base64::decode() - invalid input string: " + "string length is not multiple of encoded block size."); } - // Just because we don't need the pad, doesn't mean it is not there. The - // URL decoder should be able to handle the original encoding. - auto padding = countPadding(data, size); + auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize; + auto padding = Base::countPadding(data, size); size -= padding; - return needed - padding; + return needed - + ceil((padding * kBinaryBlockSize) / double(kEncodedBlockSize)); } size_t Base64::decodeImpl( @@ -376,44 +335,43 @@ size_t Base64::decodeImpl( size_t src_len, char* dst, size_t dst_len, - const Base64::ReverseIndex& reverse_lookup, - bool include_pad) { + const Base64::ReverseIndex& reverse_lookup) { if (!src_len) { return 0; } - auto needed = calculateDecodedSize(src, src_len, include_pad); + auto needed = calculateDecodedSize(src, src_len); if (dst_len < needed) { - throw Base64Exception( + throw BaseException( "Base64::decode() - invalid output string: " "output string is too small."); } // Handle full groups of 4 characters for (; src_len > 4; src_len -= 4, src += 4, dst += 3) { - // Each character of the 4 encode 6 bits of the original, grab each with - // the appropriate shifts to rebuild the original and then split that back - // into the original 8 bit bytes. - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12) | - (Base64ReverseLookup(src[2], reverse_lookup) << 6) | - Base64ReverseLookup(src[3], reverse_lookup); + /// Each character of the 4 encode 6 bits of the original, grab each with + /// the appropriate shifts to rebuild the original and then split that back + /// into the original 8 bit bytes. + uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase, src[1], reverse_lookup) << 12) | + (baseReverseLookup(kBase, src[2], reverse_lookup) << 6) | + baseReverseLookup(kBase, src[3], reverse_lookup); dst[0] = (last >> 16) & 0xff; dst[1] = (last >> 8) & 0xff; dst[2] = last & 0xff; } - // Handle the last 2-4 characters. This is similar to the above, but the - // last 2 characters may or may not exist. + /// Handle the last 2-4 characters. This is similar to the above, but the + /// last 2 characters may or may not exist. DCHECK(src_len >= 2); - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12); + uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase, src[1], reverse_lookup) << 12); dst[0] = (last >> 16) & 0xff; if (src_len > 2) { - last |= Base64ReverseLookup(src[2], reverse_lookup) << 6; + last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 6; dst[1] = (last >> 8) & 0xff; if (src_len > 3) { - last |= Base64ReverseLookup(src[3], reverse_lookup); + last |= baseReverseLookup(kBase, src[3], reverse_lookup); dst[2] = last & 0xff; } } @@ -437,9 +395,8 @@ void Base64::decodeUrl( const char* src, size_t src_len, char* dst, - size_t dst_len, - bool hasPad) { - decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad); + size_t dst_len) { + decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable); } std::string Base64::decodeUrl(folly::StringPiece encoded) { @@ -458,8 +415,7 @@ void Base64::decodeUrl( payload.second, &output[0], out_len, - kBase64UrlReverseIndexTable, - false); + kBase64UrlReverseIndexTable); output.resize(out_len); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 9888d97e67c54..4d7c696e214c2 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -21,30 +21,17 @@ #include #include +#include "velox/common/encode/Base.h" namespace facebook::velox::encoding { -class Base64Exception : public std::exception { +class Base64 : public Base { public: - explicit Base64Exception(const char* msg) : msg_(msg) {} - const char* what() const noexcept override { - return msg_; - } - - protected: - const char* msg_; -}; - -class Base64 { - public: - using Charset = std::array; - using ReverseIndex = std::array; - static std::string encode(const char* data, size_t len); static std::string encode(folly::StringPiece text); static std::string encode(const folly::IOBuf* text); - /// Returns encoded size for the input of the specified size. + // Returns encoded size for the input of the specified size. static size_t calculateEncodedSize(size_t size, bool withPadding = true); /// Encodes the specified number of characters from the 'data' and writes the @@ -59,8 +46,7 @@ class Base64 { /// Returns decoded size for the specified input. Adjusts the 'size' to /// subtract the length of the padding, if exists. - static size_t - calculateDecodedSize(const char* data, size_t& size, bool withPadding = true); + static size_t calculateDecodedSize(const char* data, size_t& size); /// Decodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -69,7 +55,7 @@ class Base64 { static void decode( const std::pair& payload, - std::string& outp); + std::string& output); /// Encodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -89,23 +75,10 @@ class Base64 { static size_t decode(const char* src, size_t src_len, char* dst, size_t dst_len); - static void decodeUrl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, - bool pad); - - constexpr static char kBase64Pad = '='; + static void + decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); private: - static inline size_t countPadding(const char* src, size_t len) { - DCHECK_GE(len, 2); - return src[len - 1] != kBase64Pad ? 0 : src[len - 2] != kBase64Pad ? 1 : 2; - } - - static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table); - template static std::string encodeImpl(const T& data, const Charset& charset, bool include_pad); @@ -122,8 +95,18 @@ class Base64 { size_t src_len, char* dst, size_t dst_len, - const ReverseIndex& table, - bool include_pad); + const ReverseIndex& table); + + public: + // Padding character used in encoding + constexpr static char kBase = 64; + + private: + // Size of the binary block before encoding. + constexpr static int kBinaryBlockSize = 3; + + // Size of the encoded block after encoding. + constexpr static int kEncodedBlockSize = 4; }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index d9918d53b59c5..39e5f59b4dc0b 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_encode Base64.cpp) -target_link_libraries(velox_encode PUBLIC Folly::folly) +add_library(velox_encode Base.cpp Base64.cpp) +target_link_libraries(velox_encode PUBLIC Folly::folly) \ No newline at end of file diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 35648c8921234..d52a34c99fede 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -318,8 +318,9 @@ struct FromBase64Function { auto inputSize = input.size(); result.resize( encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode(input.data(), input.size(), result.data()); - } catch (const encoding::Base64Exception& e) { + encoding::Base64::decode( + input.data(), inputSize, result.data(), result.size()); + } catch (const encoding::BaseException& e) { VELOX_USER_FAIL(e.what()); } } @@ -332,15 +333,11 @@ struct FromBase64UrlFunction { FOLLY_ALWAYS_INLINE void call( out_type& result, const arg_type& input) { - auto inputData = input.data(); auto inputSize = input.size(); - bool hasPad = - inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad); result.resize( - encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad)); - hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad. + encoding::Base64::calculateDecodedSize(input.data(), inputSize)); encoding::Base64::decodeUrl( - inputData, inputSize, result.data(), result.size(), hasPad); + input.data(), inputSize, result.data(), result.size()); } }; diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index 47a1e67dfedb8..817bf0e0b19af 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -424,11 +424,16 @@ TEST_F(BinaryFunctionsTest, fromBase64) { EXPECT_EQ(std::nullopt, fromBase64(std::nullopt)); EXPECT_EQ("", fromBase64("")); EXPECT_EQ("a", fromBase64("YQ==")); + EXPECT_EQ("ab", fromBase64("YWI=")); EXPECT_EQ("abc", fromBase64("YWJj")); EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ=")); EXPECT_EQ( "Hello World from Velox!", fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE=")); + // Check encoded strings without padding + EXPECT_EQ("a", fromBase64("YQ")); + EXPECT_EQ("ab", fromBase64("YWI")); + EXPECT_EQ("abcd", fromBase64("YWJjZA")); EXPECT_THROW(fromBase64("YQ="), VeloxUserError); EXPECT_THROW(fromBase64("YQ==="), VeloxUserError);