From 6fe67de9d3462d0c442081aa84b603e58f9cb304 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Thu, 21 Nov 2024 13:23:07 -0800 Subject: [PATCH] feat(function): Handle unescaped UTF-8 characters in Presto url_extract_* UDFs (#11535) Summary: Presto Java supports UTF-8 characters that are not control or whitespace characters appearing anywhere in a URL where a % escaped character can appear. This change modifies Velox's URIParser to do the same. Velox's URIParser would produce incorrect results when any non-ASCII character appeared anywhere in the URL and this has been fixed as well. In order to facilitate this I modified the tryGetCharLength helper function in UTF8Utils to take in a int32_t reference which it populates with the code point if the UTF-8 character is valid. It was already calculating this value and throwing it away, returning it allows me to avoid an additional call to repeat those steps and is consistent with the Airlift function on which it's based. Reviewed By: xiaoxmeng, kgpai Differential Revision: D65927918 --- velox/functions/lib/Utf8Utils.cpp | 14 ++-- velox/functions/lib/Utf8Utils.h | 5 +- velox/functions/lib/tests/Utf8Test.cpp | 53 ++++++++------- velox/functions/prestosql/FromUtf8.cpp | 15 +++-- velox/functions/prestosql/URIParser.cpp | 54 +++++++++++---- velox/functions/prestosql/URLFunctions.h | 23 +++---- .../prestosql/tests/URLFunctionsTest.cpp | 66 +++++++++++++++++-- velox/functions/sparksql/Split.h | 12 ++-- 8 files changed, 176 insertions(+), 66 deletions(-) diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 17a26a633f5f..02354db6cbc1 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -61,10 +61,15 @@ int firstByteCharLength(const char* u_input) { } // namespace -int32_t tryGetCharLength(const char* input, int64_t size) { +int32_t +tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint) { VELOX_DCHECK_NOT_NULL(input); VELOX_DCHECK_GT(size, 0); + // Set codePoint to an impossible value so it's obvious if anyone forgets to + // check the return value before using it. + codePoint = -1; + auto charLength = firstByteCharLength(input); if (charLength < 0) { return -1; @@ -72,6 +77,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 1) { // Normal ASCII: 0xxx_xxxx. + codePoint = input[0]; return 1; } @@ -89,7 +95,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 2) { // 110x_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); + codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); // Fail if overlong encoding. return codePoint < 0x80 ? -2 : 2; } @@ -106,7 +112,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 3) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00001111) << 12) | + codePoint = ((firstByte & 0b00001111) << 12) | ((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111); // Surrogates are invalid. @@ -132,7 +138,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 4) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00000111) << 18) | + codePoint = ((firstByte & 0b00000111) << 18) | ((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) | (forthByte & 0b00111111); // Fail if overlong encoding or above upper bound of Unicode. diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 369e1151e93f..eb994219ec36 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -45,12 +45,15 @@ namespace facebook::velox::functions { /// /// @param input Pointer to the first byte of the code point. Must not be null. /// @param size Number of available bytes. Must be greater than zero. +/// @param codePoint Populated with the code point it refers to. This is only +/// valid if the return value is positive. /// @return the length of the code point or negative the number of bytes in the /// invalid UTF-8 sequence. /// /// Adapted from tryGetCodePointAt in /// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java -int32_t tryGetCharLength(const char* input, int64_t size); +int32_t +tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint); /// Return the length in byte of the next UTF-8 encoded character at the /// beginning of `string`. If the beginning of `string` is not valid UTF-8 diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 4330d1d9bbd2..48a463a2a08e 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -21,53 +21,62 @@ namespace facebook::velox::functions { namespace { TEST(Utf8Test, tryCharLength) { + int32_t codepoint; // Single-byte ASCII character. - ASSERT_EQ(1, tryGetCharLength("Hello", 5)); + ASSERT_EQ(1, tryGetUtf8CharLength("Hello", 5, codepoint)); + ASSERT_EQ('H', codepoint); // 2-byte character. British pound sign. static const char* kPound = "\u00A3tail"; - ASSERT_EQ(2, tryGetCharLength(kPound, 5)); + ASSERT_EQ(2, tryGetUtf8CharLength(kPound, 5, codepoint)); + ASSERT_EQ(0xA3, codepoint); // First byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound, 1)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kPound, 1, codepoint)); // Second byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kPound + 1, 5, codepoint)); // ASCII character 't' after the pound sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kPound + 2, 5, codepoint)); // 3-byte character. Euro sign. static const char* kEuro = "\u20ACtail"; - ASSERT_EQ(3, tryGetCharLength(kEuro, 5)); + ASSERT_EQ(3, tryGetUtf8CharLength(kEuro, 5, codepoint)); + ASSERT_EQ(0x20AC, codepoint); // First byte or first 2 bytes alone are not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro, 1)); - ASSERT_EQ(-2, tryGetCharLength(kEuro, 2)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro, 1, codepoint)); + ASSERT_EQ(-2, tryGetUtf8CharLength(kEuro, 2, codepoint)); // Byte sequence starting from 2nd or 3rd byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5)); - ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5)); - ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro + 1, 5, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro + 2, 5, codepoint)); // ASCII character 't' after the euro sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kEuro + 3, 5, codepoint)); + ASSERT_EQ('t', codepoint); + ASSERT_EQ(1, tryGetUtf8CharLength(kEuro + 4, 5, codepoint)); + ASSERT_EQ('a', codepoint); // 4-byte character. Musical symbol F CLEF. static const char* kClef = "\U0001D122tail"; - ASSERT_EQ(4, tryGetCharLength(kClef, 5)); + ASSERT_EQ(4, tryGetUtf8CharLength(kClef, 5, codepoint)); + ASSERT_EQ(0x1D122, codepoint); // First byte, first 2 bytes, or first 3 bytes alone are not a valid // character. - ASSERT_EQ(-1, tryGetCharLength(kClef, 1)); - ASSERT_EQ(-2, tryGetCharLength(kClef, 2)); - ASSERT_EQ(-3, tryGetCharLength(kClef, 3)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef, 1, codepoint)); + ASSERT_EQ(-2, tryGetUtf8CharLength(kClef, 2, codepoint)); + ASSERT_EQ(-3, tryGetUtf8CharLength(kClef, 3, codepoint)); // Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 1, 3, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 2, 3, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 3, 3, codepoint)); // ASCII character 't' after the clef sign is valid. - ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kClef + 4, 5, codepoint)); + ASSERT_EQ('t', codepoint); // Test overlong encoding. auto tryCharLength = [](const std::vector& bytes) { - return tryGetCharLength( - reinterpret_cast(bytes.data()), bytes.size()); + int32_t codepoint; + return tryGetUtf8CharLength( + reinterpret_cast(bytes.data()), bytes.size(), codepoint); }; // 2-byte encoding of 0x2F. diff --git a/velox/functions/prestosql/FromUtf8.cpp b/velox/functions/prestosql/FromUtf8.cpp index c538db022961..b93df3744d5b 100644 --- a/velox/functions/prestosql/FromUtf8.cpp +++ b/velox/functions/prestosql/FromUtf8.cpp @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction { auto replacement = decoded.valueAt(row); if (!replacement.empty()) { - auto charLength = - tryGetCharLength(replacement.data(), replacement.size()); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + replacement.data(), replacement.size(), codePoint); VELOX_USER_CHECK_GT( charLength, 0, "Replacement is not a valid UTF-8 character"); VELOX_USER_CHECK_EQ( @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < value.size()) { - auto charLength = - tryGetCharLength(value.data() + pos, value.size() - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + value.data() + pos, value.size() - pos, codePoint); if (charLength < 0) { firstInvalidRow = row; return false; @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < input.size()) { - auto charLength = - tryGetCharLength(input.data() + pos, input.size() - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + input.data() + pos, input.size() - pos, codePoint); if (charLength > 0) { fixedWriter.append(std::string_view(input.data() + pos, charLength)); pos += charLength; diff --git a/velox/functions/prestosql/URIParser.cpp b/velox/functions/prestosql/URIParser.cpp index d15d98a8ef38..04a45791481f 100644 --- a/velox/functions/prestosql/URIParser.cpp +++ b/velox/functions/prestosql/URIParser.cpp @@ -15,6 +15,8 @@ */ #include "velox/functions/prestosql/URIParser.h" +#include "velox/external/utf8proc/utf8procImpl.h" +#include "velox/functions/lib/Utf8Utils.h" namespace facebook::velox::functions { @@ -44,6 +46,11 @@ Mask createMask(const std::vector& values) { return mask; } + +bool test(const Mask& mask, char value) { + return value < mask.size() && mask.test(value); +} + // a-z or A-Z. const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z'); // 0-9. @@ -135,7 +142,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { return false; } - if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) { + if (str[pos] != '%' || !test(kHex, str[pos + 1]) || + !test(kHex, str[pos + 2])) { return false; } @@ -145,7 +153,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { } // Helper function that consumes as much of `str` from `pos` as possible where a -// character passes mask or is part of a percent encoded character. +// character passes mask, is part of a percent encoded character, or is an +// allowed UTF-8 character. // // `pos` is updated to the first character in `str` that was not consumed and // `hasEncoded` is set to true if any percent encoded characters were @@ -157,7 +166,7 @@ void consume( int32_t& pos, bool& hasEncoded) { while (pos < len) { - if (mask.test(str[pos])) { + if (test(mask, str[pos])) { pos++; continue; } @@ -167,6 +176,29 @@ void consume( continue; } + // Masks cover all ASCII characters, check if this is an allowed UTF-8 + // character. + if ((unsigned char)str[pos] > 127) { + // Get the UTF-8 code point. + int32_t codePoint; + auto valid = tryGetUtf8CharLength(str + pos, len - pos, codePoint); + + // Check if it's a valid UTF-8 character. + // The range after ASCII characters up to 159 covers control characters + // which are not allowed. + if (valid > 0 && codePoint > 159) { + const auto category = utf8proc_get_property(codePoint)->category; + // White space characters are also not allowed. The range of categories + // excluded here are categories of white space. + if (category < UTF8PROC_CATEGORY_ZS || + category > UTF8PROC_CATEGORY_ZP) { + // Increment over the whole (potentially multi-byte) character. + pos += valid; + continue; + } + } + } + break; } } @@ -314,7 +346,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { while (posInAddress < len && numBytes < 16) { int32_t posInHex = posInAddress; for (int i = 0; i < 4; i++) { - if (posInHex == len || !kHex.test(str[posInHex])) { + if (posInHex == len || !test(kHex, str[posInHex])) { break; } @@ -350,7 +382,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { posInAddress = posInHex + 2; } } else { - if (posInHex == len || !kHex.test(str[posInHex + 1])) { + if (posInHex == len || !test(kHex, str[posInHex + 1])) { // Peak ahead, we can't end on a single ':'. return false; } @@ -392,7 +424,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { // Consume a string of hex digits. int32_t posInHex = posInAddress; while (posInHex < len) { - if (kHex.test(str[posInHex])) { + if (test(kHex, str[posInHex])) { posInHex++; } else { break; @@ -416,7 +448,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { int32_t posInSuffix = posInAddress; while (posInSuffix < len) { - if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) { + if (test(kIPVFutureSuffixOrUserInfo, str[posInSuffix])) { posInSuffix++; } else { break; @@ -467,7 +499,7 @@ void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInPort = pos; while (posInPort < len) { - if (kNum.test(str[posInPort])) { + if (test(kNum, str[posInPort])) { posInPort++; continue; } @@ -488,7 +520,7 @@ void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInIPV4Address = posInHost; if (tryConsumeIPV4Address(str, len, posInIPV4Address) && (posInIPV4Address == len || - kFollowingHost.test(str[posInIPV4Address]))) { + test(kFollowingHost, str[posInIPV4Address]))) { // reg-name and IPv4 addresses are hard to distinguish, a reg-name could // have a valid IPv4 address as a prefix, but treating that prefix as an // IPv4 address would make this URI invalid. We make sure that if we @@ -551,14 +583,14 @@ bool tryConsumeScheme( int32_t posInScheme = pos; // The scheme must start with a letter. - if (posInScheme == len || !kAlpha.test(str[posInScheme])) { + if (posInScheme == len || !test(kAlpha, str[posInScheme])) { return false; } // Consume the first letter. posInScheme++; - while (posInScheme < len && kScheme.test(str[posInScheme])) { + while (posInScheme < len && test(kScheme, str[posInScheme])) { posInScheme++; } diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index f5985f980c7d..4a6421d5bad0 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -80,8 +80,9 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { outputBuffer[outIndex++] = '+'; inputIndex++; } else { - const auto charLength = - tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + int32_t codePoint; + const auto charLength = tryGetUtf8CharLength( + inputBuffer + inputIndex, inputSize - inputIndex, codePoint); if (charLength > 0) { for (int i = 0; i < charLength; ++i) { charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); @@ -93,11 +94,11 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { // According to the Unicode standard the "maximal subpart of an // ill-formed subsequence" is the longest code unit subsequenece that is // either well-formed or of length 1. A replacement character should be - // written for each of these. In practice tryGetCharLength breaks most - // cases into maximal subparts, the exceptions are overlong encodings or - // subsequences outside the range of valid 4 byte sequences. In both - // these cases we should just write out a replacement character for - // every byte in the sequence. + // written for each of these. In practice tryGetUtf8CharLength breaks + // most cases into maximal subparts, the exceptions are overlong + // encodings or subsequences outside the range of valid 4 byte + // sequences. In both these cases we should just write out a + // replacement character for every byte in the sequence. size_t replaceCharactersToWriteOut = 1; if (inputIndex < inputSize - 1) { bool isMultipleInvalidSequences = @@ -108,13 +109,13 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { (inputBuffer[inputIndex] == '\xf0' && (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || // 0xf4 followed by a byte >= 0x90 looks valid to - // tryGetCharLength, but is actually outside the range of valid - // code points. + // tryGetUtf8CharLength, but is actually outside the range of + // valid code points. (inputBuffer[inputIndex] == '\xf4' && (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of - // multi-byte code points to tryGetCharLength, but are not part of - // any valid code point. + // multi-byte code points to tryGetUtf8CharLength, but are not + // part of any valid code point. (unsigned char)inputBuffer[inputIndex] > 0xf4 || inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 0439cc89758e..778bcf2355bf 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -347,7 +347,7 @@ TEST_F(URLFunctionsTest, extractHostRegName) { // Test minimal. EXPECT_EQ("a", extractHost("http://a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=", extractHost( @@ -357,9 +357,31 @@ TEST_F(URLFunctionsTest, extractHostRegName) { "123.456.789.012.abcdefg", extractHost("http://123.456.789.012.abcdefg")); // Test percent encoded. EXPECT_EQ("a b", extractHost("http://a%20b")); + // Valid UTF-8 in host reg name. + EXPECT_EQ("你好", extractHost("https://你好")); + // Valid UTF-8 in userinfo. + EXPECT_EQ("foo", extractHost("https://你好@foo")); - // Invalid character. + // Invalid ASCII character. EXPECT_EQ(std::nullopt, extractHost("http://a b")); + // Inalid UTF-8 in host reg name (it should be a 3 byte character but there's + // only 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8")); + // Inalid UTF-8 in userinfo (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82")); + // Valid UTF-8 in userinfo but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's white + // space: THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84")); + // Valid UTF-8 in userinfo but character is not allowed (it's white space: + // THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84@foo")); } TEST_F(URLFunctionsTest, extractPath) { @@ -380,11 +402,21 @@ TEST_F(URLFunctionsTest, extractPath) { EXPECT_EQ("foo", extractPath("foo")); EXPECT_EQ(std::nullopt, extractPath("BAD URL!")); EXPECT_EQ("", extractPath("http://www.yahoo.com")); - // All valid characters. + // All valid ASCII characters. EXPECT_EQ( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@", extractPath( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@")); + // Valid UTF-8 in path. + EXPECT_EQ("/你好", extractPath("https://foo.com/你好")); + // Inalid UTF-8 in path (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractPort) { @@ -430,11 +462,21 @@ TEST_F(URLFunctionsTest, extractQuery) { EXPECT_EQ("", extractQuery("http://www.yahoo.com?")); // Test non-empty query. EXPECT_EQ("a", extractQuery("http://www.yahoo.com?a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ []", extractQuery( "http://www.yahoo.com?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20[]")); + // Valid UTF-8 in query. + EXPECT_EQ("你好", extractQuery("https://foo.com?你好")); + // Inalid UTF-8 in query (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractFragment) { @@ -442,15 +484,25 @@ TEST_F(URLFunctionsTest, extractFragment) { return evaluateOnce("url_extract_fragment(c0)", url); }; - // Test empty query. + // Test empty fragment. EXPECT_EQ("", extractFragment("http://www.yahoo.com#")); - // Test non-empty query. + // Test non-empty fragment. EXPECT_EQ("a", extractFragment("http://www.yahoo.com#a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ []", extractFragment( "http://www.yahoo.com#abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20[]")); + // Valid UTF-8 in fgrament. + EXPECT_EQ("你好", extractFragment("https://foo.com#你好")); + // Inalid UTF-8 in fragment (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractParameter) { diff --git a/velox/functions/sparksql/Split.h b/velox/functions/sparksql/Split.h index 2cee345f77b2..cb8fcb076700 100644 --- a/velox/functions/sparksql/Split.h +++ b/velox/functions/sparksql/Split.h @@ -81,10 +81,11 @@ struct Split { size_t pos = 0; int32_t count = 0; while (pos < end && count < limit) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid - // character is the absolute value of result of `tryGetCharLength`. + // character is the absolute value of result of `tryGetUtf8CharLength`. charLength = -charLength; } result.add_item().setNoCopy(StringView(start + pos, charLength)); @@ -142,10 +143,13 @@ struct Split { // empty tail string at last, e.g., the result array for split('abc','d|') // is ["a","b","c",""]. if (size == 0) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = + tryGetUtf8CharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid - // character is the absolute value of result of `tryGetCharLength`. + // character is the absolute value of result of + // `tryGetUtf8CharLength`. charLength = -charLength; } offset += charLength;