diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 17a26a633f5f..02354db6cbc1 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -61,10 +61,15 @@ int firstByteCharLength(const char* u_input) { } // namespace -int32_t tryGetCharLength(const char* input, int64_t size) { +int32_t +tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint) { VELOX_DCHECK_NOT_NULL(input); VELOX_DCHECK_GT(size, 0); + // Set codePoint to an impossible value so it's obvious if anyone forgets to + // check the return value before using it. + codePoint = -1; + auto charLength = firstByteCharLength(input); if (charLength < 0) { return -1; @@ -72,6 +77,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 1) { // Normal ASCII: 0xxx_xxxx. + codePoint = input[0]; return 1; } @@ -89,7 +95,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 2) { // 110x_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); + codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); // Fail if overlong encoding. return codePoint < 0x80 ? -2 : 2; } @@ -106,7 +112,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 3) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00001111) << 12) | + codePoint = ((firstByte & 0b00001111) << 12) | ((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111); // Surrogates are invalid. @@ -132,7 +138,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 4) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00000111) << 18) | + codePoint = ((firstByte & 0b00000111) << 18) | ((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) | (forthByte & 0b00111111); // Fail if overlong encoding or above upper bound of Unicode. diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 369e1151e93f..eb994219ec36 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -45,12 +45,15 @@ namespace facebook::velox::functions { /// /// @param input Pointer to the first byte of the code point. Must not be null. /// @param size Number of available bytes. Must be greater than zero. +/// @param codePoint Populated with the code point it refers to. This is only +/// valid if the return value is positive. /// @return the length of the code point or negative the number of bytes in the /// invalid UTF-8 sequence. /// /// Adapted from tryGetCodePointAt in /// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java -int32_t tryGetCharLength(const char* input, int64_t size); +int32_t +tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint); /// Return the length in byte of the next UTF-8 encoded character at the /// beginning of `string`. If the beginning of `string` is not valid UTF-8 diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 4330d1d9bbd2..48a463a2a08e 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -21,53 +21,62 @@ namespace facebook::velox::functions { namespace { TEST(Utf8Test, tryCharLength) { + int32_t codepoint; // Single-byte ASCII character. - ASSERT_EQ(1, tryGetCharLength("Hello", 5)); + ASSERT_EQ(1, tryGetUtf8CharLength("Hello", 5, codepoint)); + ASSERT_EQ('H', codepoint); // 2-byte character. British pound sign. static const char* kPound = "\u00A3tail"; - ASSERT_EQ(2, tryGetCharLength(kPound, 5)); + ASSERT_EQ(2, tryGetUtf8CharLength(kPound, 5, codepoint)); + ASSERT_EQ(0xA3, codepoint); // First byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound, 1)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kPound, 1, codepoint)); // Second byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kPound + 1, 5, codepoint)); // ASCII character 't' after the pound sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kPound + 2, 5, codepoint)); // 3-byte character. Euro sign. static const char* kEuro = "\u20ACtail"; - ASSERT_EQ(3, tryGetCharLength(kEuro, 5)); + ASSERT_EQ(3, tryGetUtf8CharLength(kEuro, 5, codepoint)); + ASSERT_EQ(0x20AC, codepoint); // First byte or first 2 bytes alone are not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro, 1)); - ASSERT_EQ(-2, tryGetCharLength(kEuro, 2)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro, 1, codepoint)); + ASSERT_EQ(-2, tryGetUtf8CharLength(kEuro, 2, codepoint)); // Byte sequence starting from 2nd or 3rd byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5)); - ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5)); - ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro + 1, 5, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kEuro + 2, 5, codepoint)); // ASCII character 't' after the euro sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kEuro + 3, 5, codepoint)); + ASSERT_EQ('t', codepoint); + ASSERT_EQ(1, tryGetUtf8CharLength(kEuro + 4, 5, codepoint)); + ASSERT_EQ('a', codepoint); // 4-byte character. Musical symbol F CLEF. static const char* kClef = "\U0001D122tail"; - ASSERT_EQ(4, tryGetCharLength(kClef, 5)); + ASSERT_EQ(4, tryGetUtf8CharLength(kClef, 5, codepoint)); + ASSERT_EQ(0x1D122, codepoint); // First byte, first 2 bytes, or first 3 bytes alone are not a valid // character. - ASSERT_EQ(-1, tryGetCharLength(kClef, 1)); - ASSERT_EQ(-2, tryGetCharLength(kClef, 2)); - ASSERT_EQ(-3, tryGetCharLength(kClef, 3)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef, 1, codepoint)); + ASSERT_EQ(-2, tryGetUtf8CharLength(kClef, 2, codepoint)); + ASSERT_EQ(-3, tryGetUtf8CharLength(kClef, 3, codepoint)); // Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 1, 3, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 2, 3, codepoint)); + ASSERT_EQ(-1, tryGetUtf8CharLength(kClef + 3, 3, codepoint)); // ASCII character 't' after the clef sign is valid. - ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5)); + ASSERT_EQ(1, tryGetUtf8CharLength(kClef + 4, 5, codepoint)); + ASSERT_EQ('t', codepoint); // Test overlong encoding. auto tryCharLength = [](const std::vector& bytes) { - return tryGetCharLength( - reinterpret_cast(bytes.data()), bytes.size()); + int32_t codepoint; + return tryGetUtf8CharLength( + reinterpret_cast(bytes.data()), bytes.size(), codepoint); }; // 2-byte encoding of 0x2F. diff --git a/velox/functions/prestosql/FromUtf8.cpp b/velox/functions/prestosql/FromUtf8.cpp index c538db022961..b93df3744d5b 100644 --- a/velox/functions/prestosql/FromUtf8.cpp +++ b/velox/functions/prestosql/FromUtf8.cpp @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction { auto replacement = decoded.valueAt(row); if (!replacement.empty()) { - auto charLength = - tryGetCharLength(replacement.data(), replacement.size()); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + replacement.data(), replacement.size(), codePoint); VELOX_USER_CHECK_GT( charLength, 0, "Replacement is not a valid UTF-8 character"); VELOX_USER_CHECK_EQ( @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < value.size()) { - auto charLength = - tryGetCharLength(value.data() + pos, value.size() - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + value.data() + pos, value.size() - pos, codePoint); if (charLength < 0) { firstInvalidRow = row; return false; @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < input.size()) { - auto charLength = - tryGetCharLength(input.data() + pos, input.size() - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength( + input.data() + pos, input.size() - pos, codePoint); if (charLength > 0) { fixedWriter.append(std::string_view(input.data() + pos, charLength)); pos += charLength; diff --git a/velox/functions/prestosql/URIParser.cpp b/velox/functions/prestosql/URIParser.cpp index d15d98a8ef38..04a45791481f 100644 --- a/velox/functions/prestosql/URIParser.cpp +++ b/velox/functions/prestosql/URIParser.cpp @@ -15,6 +15,8 @@ */ #include "velox/functions/prestosql/URIParser.h" +#include "velox/external/utf8proc/utf8procImpl.h" +#include "velox/functions/lib/Utf8Utils.h" namespace facebook::velox::functions { @@ -44,6 +46,11 @@ Mask createMask(const std::vector& values) { return mask; } + +bool test(const Mask& mask, char value) { + return value < mask.size() && mask.test(value); +} + // a-z or A-Z. const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z'); // 0-9. @@ -135,7 +142,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { return false; } - if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) { + if (str[pos] != '%' || !test(kHex, str[pos + 1]) || + !test(kHex, str[pos + 2])) { return false; } @@ -145,7 +153,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { } // Helper function that consumes as much of `str` from `pos` as possible where a -// character passes mask or is part of a percent encoded character. +// character passes mask, is part of a percent encoded character, or is an +// allowed UTF-8 character. // // `pos` is updated to the first character in `str` that was not consumed and // `hasEncoded` is set to true if any percent encoded characters were @@ -157,7 +166,7 @@ void consume( int32_t& pos, bool& hasEncoded) { while (pos < len) { - if (mask.test(str[pos])) { + if (test(mask, str[pos])) { pos++; continue; } @@ -167,6 +176,29 @@ void consume( continue; } + // Masks cover all ASCII characters, check if this is an allowed UTF-8 + // character. + if ((unsigned char)str[pos] > 127) { + // Get the UTF-8 code point. + int32_t codePoint; + auto valid = tryGetUtf8CharLength(str + pos, len - pos, codePoint); + + // Check if it's a valid UTF-8 character. + // The range after ASCII characters up to 159 covers control characters + // which are not allowed. + if (valid > 0 && codePoint > 159) { + const auto category = utf8proc_get_property(codePoint)->category; + // White space characters are also not allowed. The range of categories + // excluded here are categories of white space. + if (category < UTF8PROC_CATEGORY_ZS || + category > UTF8PROC_CATEGORY_ZP) { + // Increment over the whole (potentially multi-byte) character. + pos += valid; + continue; + } + } + } + break; } } @@ -314,7 +346,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { while (posInAddress < len && numBytes < 16) { int32_t posInHex = posInAddress; for (int i = 0; i < 4; i++) { - if (posInHex == len || !kHex.test(str[posInHex])) { + if (posInHex == len || !test(kHex, str[posInHex])) { break; } @@ -350,7 +382,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { posInAddress = posInHex + 2; } } else { - if (posInHex == len || !kHex.test(str[posInHex + 1])) { + if (posInHex == len || !test(kHex, str[posInHex + 1])) { // Peak ahead, we can't end on a single ':'. return false; } @@ -392,7 +424,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { // Consume a string of hex digits. int32_t posInHex = posInAddress; while (posInHex < len) { - if (kHex.test(str[posInHex])) { + if (test(kHex, str[posInHex])) { posInHex++; } else { break; @@ -416,7 +448,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { int32_t posInSuffix = posInAddress; while (posInSuffix < len) { - if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) { + if (test(kIPVFutureSuffixOrUserInfo, str[posInSuffix])) { posInSuffix++; } else { break; @@ -467,7 +499,7 @@ void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInPort = pos; while (posInPort < len) { - if (kNum.test(str[posInPort])) { + if (test(kNum, str[posInPort])) { posInPort++; continue; } @@ -488,7 +520,7 @@ void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInIPV4Address = posInHost; if (tryConsumeIPV4Address(str, len, posInIPV4Address) && (posInIPV4Address == len || - kFollowingHost.test(str[posInIPV4Address]))) { + test(kFollowingHost, str[posInIPV4Address]))) { // reg-name and IPv4 addresses are hard to distinguish, a reg-name could // have a valid IPv4 address as a prefix, but treating that prefix as an // IPv4 address would make this URI invalid. We make sure that if we @@ -551,14 +583,14 @@ bool tryConsumeScheme( int32_t posInScheme = pos; // The scheme must start with a letter. - if (posInScheme == len || !kAlpha.test(str[posInScheme])) { + if (posInScheme == len || !test(kAlpha, str[posInScheme])) { return false; } // Consume the first letter. posInScheme++; - while (posInScheme < len && kScheme.test(str[posInScheme])) { + while (posInScheme < len && test(kScheme, str[posInScheme])) { posInScheme++; } diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index f5985f980c7d..4a6421d5bad0 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -80,8 +80,9 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { outputBuffer[outIndex++] = '+'; inputIndex++; } else { - const auto charLength = - tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + int32_t codePoint; + const auto charLength = tryGetUtf8CharLength( + inputBuffer + inputIndex, inputSize - inputIndex, codePoint); if (charLength > 0) { for (int i = 0; i < charLength; ++i) { charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); @@ -93,11 +94,11 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { // According to the Unicode standard the "maximal subpart of an // ill-formed subsequence" is the longest code unit subsequenece that is // either well-formed or of length 1. A replacement character should be - // written for each of these. In practice tryGetCharLength breaks most - // cases into maximal subparts, the exceptions are overlong encodings or - // subsequences outside the range of valid 4 byte sequences. In both - // these cases we should just write out a replacement character for - // every byte in the sequence. + // written for each of these. In practice tryGetUtf8CharLength breaks + // most cases into maximal subparts, the exceptions are overlong + // encodings or subsequences outside the range of valid 4 byte + // sequences. In both these cases we should just write out a + // replacement character for every byte in the sequence. size_t replaceCharactersToWriteOut = 1; if (inputIndex < inputSize - 1) { bool isMultipleInvalidSequences = @@ -108,13 +109,13 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { (inputBuffer[inputIndex] == '\xf0' && (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || // 0xf4 followed by a byte >= 0x90 looks valid to - // tryGetCharLength, but is actually outside the range of valid - // code points. + // tryGetUtf8CharLength, but is actually outside the range of + // valid code points. (inputBuffer[inputIndex] == '\xf4' && (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of - // multi-byte code points to tryGetCharLength, but are not part of - // any valid code point. + // multi-byte code points to tryGetUtf8CharLength, but are not + // part of any valid code point. (unsigned char)inputBuffer[inputIndex] > 0xf4 || inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 0439cc89758e..778bcf2355bf 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -347,7 +347,7 @@ TEST_F(URLFunctionsTest, extractHostRegName) { // Test minimal. EXPECT_EQ("a", extractHost("http://a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=", extractHost( @@ -357,9 +357,31 @@ TEST_F(URLFunctionsTest, extractHostRegName) { "123.456.789.012.abcdefg", extractHost("http://123.456.789.012.abcdefg")); // Test percent encoded. EXPECT_EQ("a b", extractHost("http://a%20b")); + // Valid UTF-8 in host reg name. + EXPECT_EQ("你好", extractHost("https://你好")); + // Valid UTF-8 in userinfo. + EXPECT_EQ("foo", extractHost("https://你好@foo")); - // Invalid character. + // Invalid ASCII character. EXPECT_EQ(std::nullopt, extractHost("http://a b")); + // Inalid UTF-8 in host reg name (it should be a 3 byte character but there's + // only 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8")); + // Inalid UTF-8 in userinfo (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82")); + // Valid UTF-8 in userinfo but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's white + // space: THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84")); + // Valid UTF-8 in userinfo but character is not allowed (it's white space: + // THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84@foo")); } TEST_F(URLFunctionsTest, extractPath) { @@ -380,11 +402,21 @@ TEST_F(URLFunctionsTest, extractPath) { EXPECT_EQ("foo", extractPath("foo")); EXPECT_EQ(std::nullopt, extractPath("BAD URL!")); EXPECT_EQ("", extractPath("http://www.yahoo.com")); - // All valid characters. + // All valid ASCII characters. EXPECT_EQ( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@", extractPath( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@")); + // Valid UTF-8 in path. + EXPECT_EQ("/你好", extractPath("https://foo.com/你好")); + // Inalid UTF-8 in path (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractPort) { @@ -430,11 +462,21 @@ TEST_F(URLFunctionsTest, extractQuery) { EXPECT_EQ("", extractQuery("http://www.yahoo.com?")); // Test non-empty query. EXPECT_EQ("a", extractQuery("http://www.yahoo.com?a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ []", extractQuery( "http://www.yahoo.com?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20[]")); + // Valid UTF-8 in query. + EXPECT_EQ("你好", extractQuery("https://foo.com?你好")); + // Inalid UTF-8 in query (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractFragment) { @@ -442,15 +484,25 @@ TEST_F(URLFunctionsTest, extractFragment) { return evaluateOnce("url_extract_fragment(c0)", url); }; - // Test empty query. + // Test empty fragment. EXPECT_EQ("", extractFragment("http://www.yahoo.com#")); - // Test non-empty query. + // Test non-empty fragment. EXPECT_EQ("a", extractFragment("http://www.yahoo.com#a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ []", extractFragment( "http://www.yahoo.com#abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20[]")); + // Valid UTF-8 in fgrament. + EXPECT_EQ("你好", extractFragment("https://foo.com#你好")); + // Inalid UTF-8 in fragment (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractParameter) { diff --git a/velox/functions/sparksql/Split.h b/velox/functions/sparksql/Split.h index 2cee345f77b2..cb8fcb076700 100644 --- a/velox/functions/sparksql/Split.h +++ b/velox/functions/sparksql/Split.h @@ -81,10 +81,11 @@ struct Split { size_t pos = 0; int32_t count = 0; while (pos < end && count < limit) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetUtf8CharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid - // character is the absolute value of result of `tryGetCharLength`. + // character is the absolute value of result of `tryGetUtf8CharLength`. charLength = -charLength; } result.add_item().setNoCopy(StringView(start + pos, charLength)); @@ -142,10 +143,13 @@ struct Split { // empty tail string at last, e.g., the result array for split('abc','d|') // is ["a","b","c",""]. if (size == 0) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = + tryGetUtf8CharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid - // character is the absolute value of result of `tryGetCharLength`. + // character is the absolute value of result of + // `tryGetUtf8CharLength`. charLength = -charLength; } offset += charLength;