diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 8dc64704d1d7d..f5985f980c7d2 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -16,15 +16,21 @@ #pragma once #include -#include -#include #include "velox/functions/Macros.h" -#include "velox/functions/lib/string/StringImpl.h" +#include "velox/functions/lib/Utf8Utils.h" #include "velox/functions/prestosql/URIParser.h" namespace facebook::velox::functions { namespace detail { +constexpr std::array kEncodedReplacementCharacterStrings = + {"%EF%BF%BD", + "%EF%BF%BD%EF%BF%BD", + "%EF%BF%BD%EF%BF%BD%EF%BF%BD", + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD"}; + FOLLY_ALWAYS_INLINE StringView submatch(const boost::cmatch& match, int idx) { const auto& sub = match[idx]; return StringView(sub.first, sub.length()); @@ -49,27 +55,86 @@ FOLLY_ALWAYS_INLINE void charEscape(unsigned char c, char* output) { /// * All other characters are converted to UTF-8 and the bytes are encoded /// as the string ``%XX`` where ``XX`` is the uppercase hexadecimal /// value of the UTF-8 byte. +/// * If the character is invalid UTF-8 each maximal subpart of an +/// ill-formed subsequence (defined below) is converted to %EF%BF%BD. template FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { auto inputSize = input.size(); - output.reserve(inputSize * 3); + // In the worst case every byte is an invalid UTF-8 character. + output.reserve(inputSize * kEncodedReplacementCharacterStrings[0].size()); auto inputBuffer = input.data(); auto outputBuffer = output.data(); + size_t inputIndex = 0; size_t outIndex = 0; - for (auto i = 0; i < inputSize; ++i) { - unsigned char p = inputBuffer[i]; + while (inputIndex < inputSize) { + unsigned char p = inputBuffer[inputIndex]; if ((p >= 'a' && p <= 'z') || (p >= 'A' && p <= 'Z') || (p >= '0' && p <= '9') || p == '-' || p == '_' || p == '.' || p == '*') { outputBuffer[outIndex++] = p; + inputIndex++; } else if (p == ' ') { outputBuffer[outIndex++] = '+'; + inputIndex++; } else { - charEscape(p, outputBuffer + outIndex); - outIndex += 3; + const auto charLength = + tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + if (charLength > 0) { + for (int i = 0; i < charLength; ++i) { + charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); + outIndex += 3; + } + + inputIndex += charLength; + } else { + // According to the Unicode standard the "maximal subpart of an + // ill-formed subsequence" is the longest code unit subsequenece that is + // either well-formed or of length 1. A replacement character should be + // written for each of these. In practice tryGetCharLength breaks most + // cases into maximal subparts, the exceptions are overlong encodings or + // subsequences outside the range of valid 4 byte sequences. In both + // these cases we should just write out a replacement character for + // every byte in the sequence. + size_t replaceCharactersToWriteOut = 1; + if (inputIndex < inputSize - 1) { + bool isMultipleInvalidSequences = + // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a + // value less than 0x90 is considered an overlong encoding. + (inputBuffer[inputIndex] == '\xe0' && + (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || + (inputBuffer[inputIndex] == '\xf0' && + (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || + // 0xf4 followed by a byte >= 0x90 looks valid to + // tryGetCharLength, but is actually outside the range of valid + // code points. + (inputBuffer[inputIndex] == '\xf4' && + (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || + // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of + // multi-byte code points to tryGetCharLength, but are not part of + // any valid code point. + (unsigned char)inputBuffer[inputIndex] > 0xf4 || + inputBuffer[inputIndex] == '\xc0' || + inputBuffer[inputIndex] == '\xc1'; + + if (isMultipleInvalidSequences) { + replaceCharactersToWriteOut = charLength * -1; + } + } + + const auto& replacementCharacterString = + kEncodedReplacementCharacterStrings + [replaceCharactersToWriteOut - 1]; + std::memcpy( + outputBuffer + outIndex, + replacementCharacterString.data(), + replacementCharacterString.size()); + outIndex += replacementCharacterString.size(); + + inputIndex += -charLength; + } } } output.resize(outIndex); diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 6830f53927824..0439cc89758e1 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -496,6 +496,35 @@ TEST_F(URLFunctionsTest, urlEncode) { urlEncode("http://\u30c6\u30b9\u30c8")); EXPECT_EQ("%7E%40%3A.-*_%2B+%E2%98%83", urlEncode("~@:.-*_+ \u2603")); EXPECT_EQ("test", urlEncode("test")); + // Test a single byte invalid UTF-8 character. + EXPECT_EQ("te%EF%BF%BDst", urlEncode("te\x88st")); + // Test a multi-byte invalid UTF-8 character. (If the first byte is between + // 0xe0 and 0xef, it should be a 3 byte character, but we only have 2 bytes + // here.) + EXPECT_EQ("te%EF%BF%BDst", urlEncode("te\xe0\xb8st")); + // Test an overlong 3 byte UTF-8 character + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xe0\x94")); + // Test an overlong 3 byte UTF-8 character with a continuation byte. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xe0\x94\x83")); + // Test an overlong 4 byte UTF-8 character + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\x84")); + // Test an overlong 4 byte UTF-8 character with continuation bytes. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\x84\x90\x90")); + // Test a 4 byte UTF-8 character outside the range of valid values. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xfa\x80\x80\x80")); + // Test the beginning of a 4 byte UTF-8 character followed by a + // non-continuation byte. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\xe0")); + // Test the invalid byte 0xc0. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xc0\x83")); + // Test the invalid byte 0xc1. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xc1\x83")); + // Test a 4 byte UTF-8 character that looks valid, but is actually outside the + // range of valid values. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xf4\x92\x83\x83")); } TEST_F(URLFunctionsTest, urlDecode) {