From 4526010233ba310f1d089abc6da6039ef92f0ee1 Mon Sep 17 00:00:00 2001 From: "Schierbeck, Cody" Date: Fri, 19 Jan 2024 10:45:27 -0800 Subject: [PATCH] Initial sparksql decode and encode function implementations --- velox/docs/functions/spark/string.rst | 18 +- velox/functions/sparksql/Register.cpp | 5 + velox/functions/sparksql/String.h | 375 ++++++++++++++++++ velox/functions/sparksql/tests/StringTest.cpp | 370 +++++++++++++++++ 4 files changed, 767 insertions(+), 1 deletion(-) diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 817eaad34ab2..1cb67a03b62e 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -306,4 +306,20 @@ Unless specified otherwise, all functions return NULL if at least one of the arg Returns string with all characters changed to uppercase. :: - SELECT upper('SparkSql'); -- SPARKSQL \ No newline at end of file + SELECT upper('SparkSql'); -- SPARKSQL + +.. spark:function:: decode(bin, charset) -> varchar + + Decodes the binary into a string using the provided charset. + Supported charsets: UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, US-ASCII. + Throws VeloxUserError for conversion errors. :: + + SELECT decode('48656C6C6F20576F726C64', "utf-8"); -- "Hello World" + +.. spark:function:: encode(string, charset) -> varbinary + + Encodes the string into a binary representation using the provided charset. + Supported charsets: UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, US-ASCII. + Throws VeloxUserError for conversion errors. 
:: + + SELECT encode('Hello World', "utf-8"); -- "48656C6C6F20576F726C64" \ No newline at end of file diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 01049c0ab59b..3f945f695826 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -264,6 +264,11 @@ void registerFunctions(const std::string& prefix) { registerFunction( {prefix + "conv"}); + registerFunction( + {prefix + "decode"}); + registerFunction( + {prefix + "encode"}); + registerFunction( {prefix + "replace"}); registerFunction( diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index 75f9e90f3d3b..7f4222d22a49 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -20,6 +20,8 @@ #include #include #include "velox/expression/VectorFunction.h" +#include "velox/expression/VectorReaders.h" +#include "velox/external/utf8proc/utf8procImpl.h" #include "velox/functions/Macros.h" #include "velox/functions/UDFOutputString.h" #include "velox/functions/lib/string/StringCore.h" @@ -1181,4 +1183,377 @@ struct FindInSetFunction { } }; +template +class CharsetFunctionBase { + protected: + VELOX_DEFINE_FUNCTION_TYPES(T); + enum class Charset { + UTF8, + UTF16BE, + UTF16LE, + ISO88591, + ASCII, + UTF16, + Unsupported + }; + + Charset getCharsetEnum(const arg_type& charset) { + if (!facebook::velox::functions::stringCore::isAscii( + charset.data(), charset.size())) { + VELOX_USER_FAIL( + "Unsupported encoding: {}. 
Only UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, and US-ASCII are supported.", + charset); + } + auto it = charsetMap.find(toLower(charset.str())); + if (it != charsetMap.end()) { + return it->second; + } + return Charset::Unsupported; + } + Charset currentCharset_; + + private: + std::string toLower(const std::string& input) { + std::string lowercase; + lowercase.resize(input.size()); + facebook::velox::functions::stringCore::lowerAscii( + lowercase.data(), input.data(), input.size()); + return lowercase; + } + const std::map charsetMap = { + {"utf-8", Charset::UTF8}, + {"utf-16be", Charset::UTF16BE}, + {"utf-16le", Charset::UTF16LE}, + {"iso-8859-1", Charset::ISO88591}, + {"us-ascii", Charset::ASCII}, + {"utf-16", Charset::UTF16}}; +}; + +/// ENCODE(string, charset) -> varbinary +/// +/// Encodes the string into a binary using the provided charset. +/// +/// Supported charsets: UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, US-ASCII. +/// Throws VeloxUserError for conversion errors. +template +struct EncodeFunction : public CharsetFunctionBase { + VELOX_DEFINE_FUNCTION_TYPES(T); + using Charset = typename CharsetFunctionBase::Charset; + + void call( + out_type& result, + const arg_type& string, + const arg_type& charset) { + if (string.empty()) { + return; + } + + Charset charsetEnum = this->getCharsetEnum(charset); + switch (charsetEnum) { + case Charset::UTF16BE: + case Charset::UTF16LE: + case Charset::UTF16: + try { + convertToUTF16(result, string, charsetEnum); + return; + } catch (const std::range_error& e) { + VELOX_USER_FAIL("Invalid UTF-16 string"); + } + case Charset::UTF8: + result += stringToHex(string.str()); + return; + case Charset::ISO88591: + convertToISO88591(result, string); + return; + case Charset::ASCII: + convertToASCII(result, string); + break; + default: + VELOX_USER_FAIL( + "Unsupported encoding: {}. 
Only UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, and US-ASCII are supported.", + charset); + } + } + + private: + /// Used to convert a character of either 32, 16, or 8 bits (wchar_t, + /// char16_t, ANY & 0xFF) to a hex string. + template + static std::string charToHex(CharType input) { + std::stringstream hexStream; + hexStream << std::hex << std::uppercase << std::setfill('0'); + + if constexpr (std::is_same::value) { + hexStream << std::setw(4) << static_cast(input); + } else if constexpr (std::is_same::value) { + hexStream << std::setw(sizeof(wchar_t) * 2) + << static_cast(input); + } else { + hexStream << std::setw(2) << (static_cast(input) & 0xff); + } + + return hexStream.str(); + } + + /// Used to convert a string of either 16 or 8 bits (std::u16string, ANY + /// iterable & 0xFF) to a hex string. + template + std::string stringToHex(const StringType& input) { + std::stringstream hexStream; + hexStream << std::hex << std::uppercase << std::setfill('0'); + + if constexpr (std::is_same::value) { + // Handle std::u16string (16-bit characters) + for (char16_t ch : input) { + hexStream << std::setw(4) << static_cast(ch); + } + } else { + // Handle std::string (8-bit characters) + for (auto byte : input) { + hexStream << std::setw(2) << (static_cast(byte) & 0xff); + } + } + + return hexStream.str(); + } + /// Used to convert a single codepoint to a hex string. + template + std::string codePointToHex(const codePointType& codepoint) { + if constexpr (std::is_same::value) { + if (codepoint < 0x10000) { + return charToHex(static_cast(codepoint)); + } else { + int32_t temp = codepoint - 0x10000; + return charToHex(static_cast((temp >> 10) + 0xD800)) + + charToHex(static_cast((temp & 0x3FF) + 0xDC00)); + } + } else if constexpr (std::is_same::value) { + // Directly handle char16_t without additional checks or conversions. 
+ return charToHex(codepoint); + } else { + VELOX_USER_FAIL("Invalid codepoint type") + } + } + + /// Converts a string to UTF-16, UTF-16BE, or UTF-16LE. + /// Spark defaults UTF-16 to UTF-16BE, so only + /// UTF-16LE requires additional work. + void convertToUTF16( + out_type& result, + const arg_type& string, + const Charset& charset) { + result.reserve(string.size() * 2 + 4); + if (charset == Charset::UTF16) { + // Spark uses Scala/Java which output BE + // by default + result += "FEFF"; + } + const char* str = string.data(); + size_t length = string.size(); + size_t position = 0; + while (position < length) { + int charLength; + int32_t codepoint = + utf8proc_codepoint(str + position, str + length, charLength); + if (codepoint <= 0xFFFF) { + if (charset == Charset::UTF16LE) { + codepoint = ((codepoint >> 8) | (codepoint << 8)) & 0xFFFF; + } + result += codePointToHex(codepoint); + } else { + codepoint -= 0x10000; + char16_t highSurrogate = ((codepoint >> 10) + 0xD800); + char16_t lowSurrogate = ((codepoint & 0x3FF) + 0xDC00); + + if (charset == Charset::UTF16LE) { + highSurrogate = ((highSurrogate >> 8) | (highSurrogate << 8)); + lowSurrogate = ((lowSurrogate >> 8) | (lowSurrogate << 8)); + } + result += codePointToHex(highSurrogate) + codePointToHex(lowSurrogate); + } + + position += charLength; + } + } + + void convertToISO88591( + out_type& result, + const arg_type utf8String) { + int i = 0; + while (i < utf8String.size()) { + const unsigned char curr = *(utf8String.begin() + i); + if (curr < 0x80) { + result += charToHex(curr); + ++i; + } else if (curr >= 0xC2 && curr <= 0xC3 && i + 1 < utf8String.size()) { + const unsigned char next = *(utf8String.begin() + i + 1); + if (curr == 0xC2) { + result += charToHex(next); + } else if (curr == 0xC3) { + result += charToHex(next + 0x40); + } + i += 2; + } else { + VELOX_USER_FAIL("Invalid character for ISO-8859-1 encoding.") + } + } + } + void convertToASCII( + out_type& result, + const arg_type& input) { + int i = 
0; + while (i < input.size()) { + const unsigned char curr = *(input.begin() + i); + if (curr <= 0x7F) { // US-ASCII is 0x00-0x7F inclusive; decodeFromASCII accepts the same range. + result += charToHex(*(input.begin() + i)); + ++i; + } else { + VELOX_USER_FAIL("Invalid character for US-ASCII encoding.") + } + } + } +}; + +/// DECODE(bin, charset) -> varchar +/// +/// Decodes the binary into a string using the provided charset. +/// Supported charsets: UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, US-ASCII. +/// +/// Throws VeloxUserError for conversion errors. +template +struct DecodeFunction : public CharsetFunctionBase { + VELOX_DEFINE_FUNCTION_TYPES(T); + using Charset = typename CharsetFunctionBase::Charset; + + void call( + out_type& result, + const arg_type& bin, + const arg_type& encoding) { + if (bin.empty()) { + return; + } + Charset charsetEnum = this->getCharsetEnum(encoding); + switch (charsetEnum) { + case Charset::UTF16: + case Charset::UTF16LE: + case Charset::UTF16BE: + this->currentCharset_ = charsetEnum; + decodeFromUTF16(bin, result); + return; + case Charset::UTF8: + decodeFromUTF8(bin, result); + return; + case Charset::ISO88591: + decodeFromISO(bin, result); + return; + case Charset::ASCII: + decodeFromASCII(bin, result); + return; + default: + VELOX_USER_FAIL( + "Unsupported encoding: {}. 
Only UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, and US-ASCII are supported.", + encoding); + } + } + + private: + void decodeFromUTF16( + const arg_type& bin, + out_type& result) { + uint16_t highSurrogatePair = 0; + uint16_t bom = 0; + int numChars = 0; + size_t index; + result.reserve(bin.size()); // Upper bound: every 4 hex digits yield at most 3 UTF-8 bytes. + for (index = 0; index + 3 < bin.size(); index += 4) { // No subtraction, so inputs shorter than 3 cannot wrap size_t. + uint16_t bytePair; + std::sscanf(bin.data() + index, "%04hx", &bytePair); + if (index == 0 && (bytePair == 0xFFFE || bytePair == 0xFEFF)) { + bom = bytePair; + continue; + } + if (this->currentCharset_ == Charset::UTF16LE || bom == 0xFFFE) { + bytePair = ((bytePair >> 8) | (bytePair << 8)); + } + // Look for a low-surrogate + if (highSurrogatePair != 0) { + // We are in the middle of a surrogate pair + if (bytePair >= 0xDC00 && bytePair <= 0xDFFF) { + uint32_t codePoint = ((highSurrogatePair - 0xD800) << 10) + + (bytePair - 0xDC00) + 0x10000; + numChars += utf8proc_encode_char( + codePoint, + reinterpret_cast(result.data() + numChars)); + + highSurrogatePair = 0; + } else { + VELOX_USER_FAIL("Invalid UTF-16 surrogate pair"); + } + } + // Encountered a high-surrogate + else if (bytePair >= 0xD800 && bytePair <= 0xDBFF) { + highSurrogatePair = bytePair; + continue; + } else { + // Not a surrogate pair + numChars += utf8proc_encode_char( + bytePair, + reinterpret_cast(result.data() + numChars)); + } + } + VELOX_USER_CHECK_EQ(highSurrogatePair, 0, "Unpaired UTF-16 high surrogate"); + VELOX_USER_CHECK_EQ(index, bin.size(), "Improperly sized UTF-16 string"); + result.resize(numChars); + } + + void decodeFromUTF8( + const arg_type& bin, + out_type& result) { + for (size_t i = 0; i < bin.size(); i += 2) { + uint16_t byte; + std::sscanf(bin.data() + i, "%02hx", &byte); + const char temp[2] = {static_cast(byte), '\0'}; + result += temp; + } + } + + void decodeFromISO( + const arg_type& bin, + out_type& result) { + int numChars = 0; + result.reserve(bin.size() + 1); // Bytes >= 0x80 expand to 2 UTF-8 bytes, so allow 2 per hex pair. + for (int i = 0; i < bin.size(); i += 2) 
{ + uint16_t byte; + std::sscanf(bin.data() + i, "%02hx", &byte); + VELOX_USER_CHECK_LE( + byte, 0xFF, "Invalid character for ISO-8859-1 encoding."); + if (byte >= 0x80) { + if (byte < 0xC0) { + result.data()[numChars++] = 0xC2; + } else { + result.data()[numChars++] = 0xC3; + byte -= 0x40; + } + } + result.data()[numChars++] = byte; + } + result.resize(numChars); + } + + void decodeFromASCII( + const arg_type bin, + out_type& result) { + int numChars = 0; + result.reserve(bin.size() / 2 + 1); + for (size_t i = 0; i < bin.size(); i += 2) { + unsigned int byte; + std::sscanf(bin.data() + i, "%02x", &byte); + VELOX_USER_CHECK_LE( + byte, 0x7F, "Invalid character for US-ASCII encoding.") + result.data()[numChars++] = byte; + } + result.resize(numChars); + } +}; + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/StringTest.cpp b/velox/functions/sparksql/tests/StringTest.cpp index 60bf3516bf06..9125f01c62df 100644 --- a/velox/functions/sparksql/tests/StringTest.cpp +++ b/velox/functions/sparksql/tests/StringTest.cpp @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "boost/endian.hpp" +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" #include "velox/type/Type.h" #include +#include namespace facebook::velox::functions::sparksql::test { namespace { @@ -26,6 +29,286 @@ class StringTest : public SparkFunctionBaseTest { // This is a five codepoint sequence that renders as a single emoji. static constexpr char kWomanFacepalmingLightSkinTone[] = "\xF0\x9F\xA4\xA6\xF0\x9F\x8F\xBB\xE2\x80\x8D\xE2\x99\x80\xEF\xB8\x8F"; + std::string bom = boost::endian::order::native == boost::endian::order::big + ? 
"FEFF" + : "FFFE"; + std::map>> + encodeDecodeTestCases = { + {"utf-8", + {{"48656C6C6F20576F726C64", "Hello World"}, + {"", ""}, + {"E298BA", "☺"}, + {"F09F9881", "😁"}}}, + {"iso-8859-1", + {{"48656C6C6F20576F726C64", "Hello World"}, + {"A1", "¡"}, + {"", ""}, + {"E7F364FD2073E768ECEB7262E8E76B", "çódý sçhìërbèçk"}}}, + {"us-ascii", + {{"48656C6C6F20576F726C64", "Hello World"}, {"7E", "~"}, {"", ""}}}, + {"utf-16be", + {{"00480065006C006C006F00200057006F0072006C0064", "Hello World"}, + {"004100420043", "ABC"}, + {"D83DDE02", "😂"}, + {"266B00A100530069006E00670069006E0067002000690073002000660075006E0021266B", + "♫¡Singing is fun!♫"}}}, + {"utf-16le", + {{"480065006C006C006F00200057006F0072006C006400", "Hello World"}, + {"410042004300", "ABC"}, + {"", ""}, + {"3DD802DE", "😂"}, + {"6B26A100530069006E00670069006E0067002000690073002000660075006E0021006B26", + "♫¡Singing is fun!♫"}}}, + {"utf-16", + {{"FEFF00480065006C006C006F00200057006F0072006C0064", "Hello World"}, + {"FEFFD83DDE02", "😂"}, + {"", ""}, + {"FEFF266B00A100530069006E00670069006E0067002000690073002000660075006E0021266B", + "♫¡Singing is fun!♫"}}}}; + + std::string generateRandomString(size_t length) { + const std::string characters = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + std::random_device random_device; + std::mt19937 generator(random_device()); + std::uniform_int_distribution<> distribution(0, characters.size() - 1); + + std::string random_string; + for (size_t i = 0; i < length; ++i) { + random_string += characters[distribution(generator)]; + } + + return random_string; + } + + std::optional decodeString( + std::optional binary, + std::optional encoding) { + return evaluateOnce( + "decode(c0, c1)", {binary, encoding}, {VARBINARY(), VARCHAR()}); + } + + std::optional encodeString( + std::optional string, + std::optional encoding) { + return evaluateOnce( + "encode(c0, c1)", {string, encoding}, {VARCHAR(), VARCHAR()}); + } + std::optional encodeDecode( + std::optional 
binary, + std::optional encoding) { + return evaluateOnce( + "encode(decode(c0, c1), c1)", + {binary, encoding}, + {VARBINARY(), VARCHAR()}); + } + + std::optional decodeEncode( + std::optional string, + std::optional encoding) { + return evaluateOnce( + "decode(encode(c0, c1), c1)", + {string, encoding}, + {VARCHAR(), VARCHAR()}); + } + std::optional ascii(std::optional arg) { + return evaluateOnce("ascii(c0)", arg); + } + + std::optional chr(std::optional arg) { + return evaluateOnce("chr(c0)", arg); + } + + std::optional instr( + std::optional haystack, + std::optional needle) { + return evaluateOnce("instr(c0, c1)", haystack, needle); + } + + std::optional length(std::optional arg) { + return evaluateOnce("length(c0)", arg); + } + + std::optional length_bytes(std::optional arg) { + return evaluateOnce( + "length(c0)", {arg}, {VARBINARY()}); + } + + std::optional trim(std::optional srcStr) { + return evaluateOnce("trim(c0)", srcStr); + } + + std::optional trim( + std::optional trimStr, + std::optional srcStr) { + return evaluateOnce("trim(c0, c1)", trimStr, srcStr); + } + + std::optional ltrim(std::optional srcStr) { + return evaluateOnce("ltrim(c0)", srcStr); + } + + std::optional ltrim( + std::optional trimStr, + std::optional srcStr) { + return evaluateOnce("ltrim(c0, c1)", trimStr, srcStr); + } + + std::optional rtrim(std::optional srcStr) { + return evaluateOnce("rtrim(c0)", srcStr); + } + + std::optional rtrim( + std::optional trimStr, + std::optional srcStr) { + return evaluateOnce("rtrim(c0, c1)", trimStr, srcStr); + } + + std::optional md5(std::optional arg) { + return evaluateOnce( + "md5(c0)", {arg}, {VARBINARY()}); + } + + std::optional sha1(std::optional arg) { + return evaluateOnce( + "sha1(c0)", {arg}, {VARBINARY()}); + } + + std::optional sha2( + std::optional str, + std::optional bitLength) { + return evaluateOnce( + "sha2(cast(c0 as varbinary), c1)", str, bitLength); + } + + bool compareFunction( + const std::string& function, + const 
std::optional& str, + const std::optional& pattern) { + return evaluateOnce(function + "(c0, c1)", str, pattern).value(); + } + + std::optional startsWith( + const std::optional& str, + const std::optional& pattern) { + return evaluateOnce("startsWith(c0, c1)", str, pattern); + } + std::optional endsWith( + const std::optional& str, + const std::optional& pattern) { + return evaluateOnce("endsWith(c0, c1)", str, pattern); + } + std::optional contains( + const std::optional& str, + const std::optional& pattern) { + return evaluateOnce("contains(c0, c1)", str, pattern); + } + + std::optional substring( + std::optional str, + std::optional start) { + return evaluateOnce("substring(c0, c1)", str, start); + } + + std::optional substring( + std::optional str, + std::optional start, + std::optional length) { + return evaluateOnce( + "substring(c0, c1, c2)", str, start, length); + } + + std::optional left( + std::optional str, + std::optional length) { + return evaluateOnce("left(c0, c1)", str, length); + } + + std::optional substringIndex( + const std::string& str, + const std::string& delim, + int32_t count) { + return evaluateOnce( + "substring_index(c0, c1, c2)", str, delim, count); + } + + std::optional overlay( + std::optional input, + std::optional replace, + std::optional pos, + std::optional len) { + // overlay is a keyword of DuckDB, use double quote avoid parse error. + return evaluateOnce( + "\"overlay\"(c0, c1, c2, c3)", input, replace, pos, len); + } + + std::optional overlayVarbinary( + std::optional input, + std::optional replace, + std::optional pos, + std::optional len) { + // overlay is a keyword of DuckDB, use double quote avoid parse error. 
+ return evaluateOnce( + "\"overlay\"(cast(c0 as varbinary), cast(c1 as varbinary), c2, c3)", + input, + replace, + pos, + len); + } + std::optional rpad( + std::optional string, + std::optional size, + std::optional padString) { + return evaluateOnce( + "rpad(c0, c1, c2)", string, size, padString); + } + + std::optional lpad( + std::optional string, + std::optional size, + std::optional padString) { + return evaluateOnce( + "lpad(c0, c1, c2)", string, size, padString); + } + + std::optional rpad( + std::optional string, + std::optional size) { + return evaluateOnce("rpad(c0, c1)", string, size); + } + + std::optional lpad( + std::optional string, + std::optional size) { + return evaluateOnce("lpad(c0, c1)", string, size); + } + + std::optional conv( + std::optional str, + std::optional fromBase, + std::optional toBase) { + return evaluateOnce("conv(c0, c1, c2)", str, fromBase, toBase); + } + + std::optional replace( + std::optional str, + std::optional replaced) { + return evaluateOnce("replace(c0, c1)", str, replaced); + } + + std::optional replace( + std::optional str, + std::optional replaced, + std::optional replacement) { + return evaluateOnce( + "replace(c0, c1, c2)", str, replaced, replacement); + } + + std::optional findInSet( + std::optional str, + std::optional strArray) { + return evaluateOnce("find_in_set(c0, c1)", str, strArray); + } }; TEST_F(StringTest, ascii) { @@ -862,5 +1145,92 @@ TEST_F(StringTest, trim) { trimWithTrimStr("\u6570", "\u6574\u6570 \u6570\u636E!"), "\u6574\u6570 \u6570\u636E!"); } +TEST_F(StringTest, decodeString) { + for (const auto& testCase : encodeDecodeTestCases) { + const auto& encoding = testCase.first; + const auto& pairs = testCase.second; + + for (const auto& pair : pairs) { + std::optional encodedString(pair.first); + std::optional expectedDecodedString(pair.second); + + EXPECT_EQ(decodeString(encodedString, encoding), expectedDecodedString); + } + } +} + +TEST_F(StringTest, encodeString) { + for (const auto& testCase : 
encodeDecodeTestCases) { + const auto& encoding = testCase.first; + const auto& pairs = testCase.second; + + for (const auto& pair : pairs) { + std::optional expectedEncodedString(pair.first); + std::optional string(pair.second); + + EXPECT_EQ(encodeString(string, encoding), expectedEncodedString); + } + } +} + +TEST_F(StringTest, encodeDecode) { + for (const auto& testCase : encodeDecodeTestCases) { + const auto& encoding = testCase.first; + const auto& pairs = testCase.second; + + for (const auto& pair : pairs) { + EXPECT_EQ(encodeDecode(pair.first, encoding), pair.first); + } + } +} + +TEST_F(StringTest, decodeEncode) { + for (const auto& testCase : encodeDecodeTestCases) { + const auto& encoding = testCase.first; + const auto& pairs = testCase.second; + + for (const auto& pair : pairs) { + EXPECT_EQ(decodeEncode(pair.second, encoding), pair.second); + } + } +} + +TEST_F(StringTest, randomEncodeDecode) { + for (int i = 0; i < 2000; i++) { + std::string randomString = generateRandomString(200); + EXPECT_EQ(decodeEncode(randomString, "UTF-8"), randomString); + } +} + +TEST_F(StringTest, encodeErrors) { + std::string invalidString = "Ψ\xFF\xFFΣΓΔA"; + std::string invalidASCII = "😀"; + std::string invalidEncoding = "UTF-84"; + + VELOX_ASSERT_THROW( + encodeString(invalidASCII, "us-ascii"), + "Invalid character for US-ASCII encoding."); + VELOX_ASSERT_THROW( + encodeString(invalidASCII, "iso-8859-1"), + "Invalid character for ISO-8859-1 encoding."); + VELOX_ASSERT_THROW( + encodeString(invalidASCII, invalidEncoding), + "Unsupported encoding: UTF-84. 
Only UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, and US-ASCII are supported."); +} + +TEST_F(StringTest, decodeErrors) { + std::string invalidString = "Ψ\xFF\xFFΣΓΔA"; + std::string invalidEncoding = "UTF-84"; + VELOX_ASSERT_THROW( + decodeString(invalidString, "us-ascii"), + "Invalid character for US-ASCII encoding."); + VELOX_ASSERT_THROW( + decodeString(invalidString, "iso-8859-1"), + "Invalid character for ISO-8859-1 encoding."); + VELOX_ASSERT_THROW( + decodeString(invalidString, invalidEncoding), + "Unsupported encoding: UTF-84. Only UTF-8, UTF-16, UTF-16BE, UTF-16LE, ISO-8859-1, and US-ASCII are supported."); +} + } // namespace } // namespace facebook::velox::functions::sparksql::test