diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 17a26a633f5fb..2aa1f31fd6e3e 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -61,7 +61,7 @@ int firstByteCharLength(const char* u_input) { } // namespace -int32_t tryGetCharLength(const char* input, int64_t size) { +int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint) { VELOX_DCHECK_NOT_NULL(input); VELOX_DCHECK_GT(size, 0); @@ -72,6 +72,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 1) { // Normal ASCII: 0xxx_xxxx. + codePoint = input[0]; return 1; } @@ -89,7 +90,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 2) { // 110x_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); + codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); // Fail if overlong encoding. return codePoint < 0x80 ? -2 : 2; } @@ -106,7 +107,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 3) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00001111) << 12) | + codePoint = ((firstByte & 0b00001111) << 12) | ((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111); // Surrogates are invalid. @@ -132,7 +133,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 4) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00000111) << 18) | + codePoint = ((firstByte & 0b00000111) << 18) | ((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) | (forthByte & 0b00111111); // Fail if overlong encoding or above upper bound of Unicode. diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 369e1151e93fa..6373477887fdc 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -45,12 +45,14 @@ namespace facebook::velox::functions { /// /// @param input Pointer to the first byte of the code point. Must not be null. /// @param size Number of available bytes. Must be greater than zero. +/// @param codePoint Populated with the code point it refers to. This is only +/// valid if the return value is positive. /// @return the length of the code point or negative the number of bytes in the /// invalid UTF-8 sequence. /// /// Adapted from tryGetCodePointAt in /// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java -int32_t tryGetCharLength(const char* input, int64_t size); +int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint); /// Return the length in byte of the next UTF-8 encoded character at the /// beginning of `string`. If the beginning of `string` is not valid UTF-8 diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 4330d1d9bbd2c..0fe0e4697cb51 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -21,53 +21,62 @@ namespace facebook::velox::functions { namespace { TEST(Utf8Test, tryCharLength) { + int32_t codepoint; // Single-byte ASCII character. - ASSERT_EQ(1, tryGetCharLength("Hello", 5)); + ASSERT_EQ(1, tryGetCharLength("Hello", 5, codepoint)); + ASSERT_EQ('H', codepoint); // 2-byte character. British pound sign. static const char* kPound = "\u00A3tail"; - ASSERT_EQ(2, tryGetCharLength(kPound, 5)); + ASSERT_EQ(2, tryGetCharLength(kPound, 5, codepoint)); + ASSERT_EQ(0xA3, codepoint); // First byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound, 1)); + ASSERT_EQ(-1, tryGetCharLength(kPound, 1, codepoint)); // Second byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5)); + ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5, codepoint)); // ASCII character 't' after the pound sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5)); + ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5, codepoint)); // 3-byte character. Euro sign. static const char* kEuro = "\u20ACtail"; - ASSERT_EQ(3, tryGetCharLength(kEuro, 5)); + ASSERT_EQ(3, tryGetCharLength(kEuro, 5, codepoint)); + ASSERT_EQ(0x20AC, codepoint); // First byte or first 2 bytes alone are not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro, 1)); - ASSERT_EQ(-2, tryGetCharLength(kEuro, 2)); + ASSERT_EQ(-1, tryGetCharLength(kEuro, 1, codepoint)); + ASSERT_EQ(-2, tryGetCharLength(kEuro, 2, codepoint)); // Byte sequence starting from 2nd or 3rd byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5)); - ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5)); - ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5)); + ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5, codepoint)); + ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5, codepoint)); + ASSERT_EQ('t', codepoint); // ASCII character 't' after the euro sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5)); + ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5, codepoint)); + ASSERT_EQ('a', codepoint); // 4-byte character. Musical symbol F CLEF. static const char* kClef = "\U0001D122tail"; - ASSERT_EQ(4, tryGetCharLength(kClef, 5)); + ASSERT_EQ(4, tryGetCharLength(kClef, 5, codepoint)); + ASSERT_EQ(0x1D122, codepoint); // First byte, first 2 bytes, or first 3 bytes alone are not a valid // character. - ASSERT_EQ(-1, tryGetCharLength(kClef, 1)); - ASSERT_EQ(-2, tryGetCharLength(kClef, 2)); - ASSERT_EQ(-3, tryGetCharLength(kClef, 3)); + ASSERT_EQ(-1, tryGetCharLength(kClef, 1, codepoint)); + ASSERT_EQ(-2, tryGetCharLength(kClef, 2, codepoint)); + ASSERT_EQ(-3, tryGetCharLength(kClef, 3, codepoint)); // Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3, codepoint)); // ASCII character 't' after the clef sign is valid. - ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5)); + ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5, codepoint)); + ASSERT_EQ('t', codepoint); // Test overlong encoding. auto tryCharLength = [](const std::vector& bytes) { + int32_t codepoint; return tryGetCharLength( - reinterpret_cast(bytes.data()), bytes.size()); + reinterpret_cast(bytes.data()), bytes.size(), codepoint); }; // 2-byte encoding of 0x2F. diff --git a/velox/functions/prestosql/FromUtf8.cpp b/velox/functions/prestosql/FromUtf8.cpp index c538db0229618..0db0564bdbbe5 100644 --- a/velox/functions/prestosql/FromUtf8.cpp +++ b/velox/functions/prestosql/FromUtf8.cpp @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction { auto replacement = decoded.valueAt(row); if (!replacement.empty()) { + int32_t codePoint; auto charLength = - tryGetCharLength(replacement.data(), replacement.size()); + tryGetCharLength(replacement.data(), replacement.size(), codePoint); VELOX_USER_CHECK_GT( charLength, 0, "Replacement is not a valid UTF-8 character"); VELOX_USER_CHECK_EQ( @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < value.size()) { + int32_t codePoint; auto charLength = - tryGetCharLength(value.data() + pos, value.size() - pos); + tryGetCharLength(value.data() + pos, value.size() - pos, codePoint); if (charLength < 0) { firstInvalidRow = row; return false; @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < input.size()) { + int32_t codePoint; auto charLength = - tryGetCharLength(input.data() + pos, input.size() - pos); + tryGetCharLength(input.data() + pos, input.size() - pos, codePoint); if (charLength > 0) { fixedWriter.append(std::string_view(input.data() + pos, charLength)); pos += charLength; diff --git a/velox/functions/prestosql/URIParser.cpp b/velox/functions/prestosql/URIParser.cpp index 3f56a7d41f5e8..323e348fa7050 100644 --- a/velox/functions/prestosql/URIParser.cpp +++ b/velox/functions/prestosql/URIParser.cpp @@ -15,6 +15,8 @@ */ #include "velox/functions/prestosql/URIParser.h" +#include "velox/external/utf8proc/utf8procImpl.h" +#include "velox/functions/lib/Utf8Utils.h" namespace facebook::velox::functions { @@ -40,6 +42,11 @@ Mask createMask(const std::vector& values) { return mask; } + +bool test(const Mask& mask, char value) { + return value < mask.size() && mask.test(value); +} + // a-z or A-Z. const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z'); // 0-9. @@ -128,7 +135,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { return false; } - if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) { + if (str[pos] != '%' || !test(kHex, str[pos + 1]) || + !test(kHex, str[pos + 2])) { return false; } @@ -138,7 +146,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { } // Helper function that consumes as much of `str` from `pos` as possible where a -// character passes mask or is part of a percent encoded character. +// character passes mask, is part of a percent encoded character, or is an +// allowed UTF-8 character. // // `pos` is updated to the first character in `str` that was not consumed and // `hasEncoded` is set to true if any percent encoded characters were @@ -150,7 +159,7 @@ void consume( int32_t& pos, bool& hasEncoded) { while (pos < len) { - if (mask.test(str[pos])) { + if (test(mask, str[pos])) { pos++; continue; } @@ -160,6 +169,29 @@ void consume( continue; } + // Masks cover all ASCII characters, check if this is an allowed UTF-8 + // character. + // The range after ASCII characters up to 159 covers control characters + // which are not allowed. + if ((unsigned char)str[pos] > 159) { + // Get the UTF-8 code point. + int32_t codePoint; + auto valid = tryGetCharLength(str + pos, len - pos, codePoint); + + // Check if it's a valid UTF-8 character. + if (valid > 0) { + const auto category = utf8proc_get_property(codePoint)->category; + // White space characters are also not allowed. The range of categories + // excluded here are categories of white space. + if (category < UTF8PROC_CATEGORY_ZS || + category > UTF8PROC_CATEGORY_ZP) { + // Increment over the whole (potentially multi-byte) character. + pos += valid; + continue; + } + } + } + break; } } @@ -297,7 +329,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { while (posInAddress < len && numBytes < 16) { int32_t posInHex = posInAddress; for (int i = 0; i < 4; i++) { - if (posInHex == len || !kHex.test(str[posInHex])) { + if (posInHex == len || !test(kHex, str[posInHex])) { break; } @@ -333,7 +365,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { posInAddress = posInHex + 2; } } else { - if (posInHex == len || !kHex.test(str[posInHex + 1])) { + if (posInHex == len || !test(kHex, str[posInHex + 1])) { // Peak ahead, we can't end on a single ':'. return false; } @@ -375,7 +407,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { // Consume a string of hex digits. int32_t posInHex = posInAddress; while (posInHex < len) { - if (kHex.test(str[posInHex])) { + if (test(kHex, str[posInHex])) { posInHex++; } else { break; @@ -399,7 +431,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { int32_t posInSuffix = posInAddress; while (posInSuffix < len) { - if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) { + if (test(kIPVFutureSuffixOrUserInfo, str[posInSuffix])) { posInSuffix++; } else { break; @@ -450,7 +482,7 @@ void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInPort = pos; while (posInPort < len) { - if (kNum.test(str[posInPort])) { + if (test(kNum, str[posInPort])) { posInPort++; continue; } @@ -471,7 +503,7 @@ void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInIPV4Address = posInHost; if (tryConsumeIPV4Address(str, len, posInIPV4Address) && (posInIPV4Address == len || - kFollowingHost.test(str[posInIPV4Address]))) { + test(kFollowingHost, str[posInIPV4Address]))) { // reg-name and IPv4 addresses are hard to distinguish, a reg-name could // have a valid IPv4 address as a prefix, but treating that prefix as an // IPv4 address would make this URI invalid. We make sure that if we @@ -534,14 +566,14 @@ bool tryConsumeScheme( int32_t posInScheme = pos; // The scheme must start with a letter. - if (posInScheme == len || !kAlpha.test(str[posInScheme])) { + if (posInScheme == len || !test(kAlpha, str[posInScheme])) { return false; } // Consume the first letter. posInScheme++; - while (posInScheme < len && kScheme.test(str[posInScheme])) { + while (posInScheme < len && test(kScheme, str[posInScheme])) { posInScheme++; } diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 1edff6f30a142..77aeab41ba91c 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -74,8 +74,9 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { outputBuffer[outIndex++] = '+'; inputIndex++; } else { - const auto charLength = - tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + int32_t codePoint; + const auto charLength = tryGetCharLength( + inputBuffer + inputIndex, inputSize - inputIndex, codePoint); if (charLength > 0) { for (int i = 0; i < charLength; ++i) { charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 3f116d9632e06..13a89be93834a 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -347,7 +347,7 @@ TEST_F(URLFunctionsTest, extractHostRegName) { // Test minimal. EXPECT_EQ("a", extractHost("http://a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=", extractHost( @@ -357,9 +357,31 @@ TEST_F(URLFunctionsTest, extractHostRegName) { "123.456.789.012.abcdefg", extractHost("http://123.456.789.012.abcdefg")); // Test percent encoded. EXPECT_EQ("a b", extractHost("http://a%20b")); + // Valid UTF-8 in host reg name. + EXPECT_EQ("你好", extractHost("https://你好")); + // Valid UTF-8 in userinfo. + EXPECT_EQ("foo", extractHost("https://你好@foo")); - // Invalid character. + // Invalid ASCII character. EXPECT_EQ(std::nullopt, extractHost("http://a b")); + // Inalid UTF-8 in host reg name (it should be a 3 byte character but there's + // only 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8")); + // Inalid UTF-8 in userinfo (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82")); + // Valid UTF-8 in userinfo but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's white + // space: THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84")); + // Valid UTF-8 in userinfo but character is not allowed (it's white space: + // THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84@foo")); } TEST_F(URLFunctionsTest, extractPath) { @@ -380,11 +402,21 @@ TEST_F(URLFunctionsTest, extractPath) { EXPECT_EQ("foo", extractPath("foo")); EXPECT_EQ(std::nullopt, extractPath("BAD URL!")); EXPECT_EQ("", extractPath("http://www.yahoo.com")); - // All valid characters. + // All valid ASCII characters. EXPECT_EQ( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@", extractPath( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@")); + // Valid UTF-8 in path. + EXPECT_EQ("/你好", extractPath("https://foo.com/你好")); + // Inalid UTF-8 in path (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractPort) { @@ -430,11 +462,21 @@ TEST_F(URLFunctionsTest, extractQuery) { EXPECT_EQ("", extractQuery("http://www.yahoo.com?")); // Test non-empty query. EXPECT_EQ("a", extractQuery("http://www.yahoo.com?a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", extractQuery( "http://www.yahoo.com?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); + // Valid UTF-8 in query. + EXPECT_EQ("你好", extractQuery("https://foo.com?你好")); + // Inalid UTF-8 in query (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractFragment) { @@ -442,15 +484,25 @@ TEST_F(URLFunctionsTest, extractFragment) { return evaluateOnce("url_extract_fragment(c0)", url); }; - // Test empty query. + // Test empty fragment. EXPECT_EQ("", extractFragment("http://www.yahoo.com#")); - // Test non-empty query. + // Test non-empty fragment. EXPECT_EQ("a", extractFragment("http://www.yahoo.com#a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", extractFragment( "http://www.yahoo.com#abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); + // Valid UTF-8 in fgrament. + EXPECT_EQ("你好", extractFragment("https://foo.com#你好")); + // Inalid UTF-8 in fragment (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractParameter) { diff --git a/velox/functions/sparksql/Split.h b/velox/functions/sparksql/Split.h index 2cee345f77b28..502735e258f8f 100644 --- a/velox/functions/sparksql/Split.h +++ b/velox/functions/sparksql/Split.h @@ -81,7 +81,8 @@ struct Split { size_t pos = 0; int32_t count = 0; while (pos < end && count < limit) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetCharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid // character is the absolute value of result of `tryGetCharLength`. @@ -142,7 +143,8 @@ struct Split { // empty tail string at last, e.g., the result array for split('abc','d|') // is ["a","b","c",""]. if (size == 0) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetCharLength(start + pos, end - pos, codePOint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid // character is the absolute value of result of `tryGetCharLength`.