From 0bb581d556a336e0b63f112952ab7e800ba561db Mon Sep 17 00:00:00 2001 From: xumingming Date: Sun, 26 Nov 2023 20:14:56 +0800 Subject: [PATCH] Optimize LIKE when user specifies an escape char Currently, LIKE is optimized only when the user does not specify an escape char. This commit adds the same optimizations (kPrefix, kSuffix, kFixed, kSubstring) for patterns where the user specifies an escape char. --- velox/functions/lib/Re2Functions.cpp | 296 ++++++++++++------ velox/functions/lib/Re2Functions.h | 75 ++++- .../functions/lib/tests/Re2FunctionsTest.cpp | 181 +++++++++-- 3 files changed, 428 insertions(+), 124 deletions(-) diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index 59bfd8451105..7299f86c72ae 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -389,7 +389,7 @@ class Re2SearchAndExtract final : public VectorFunction { // Match string 'input' with a fixed pattern (with no wildcard characters). bool matchExactPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() == pattern.size() && std::memcmp(input.data(), pattern.data(), length) == 0; @@ -398,7 +398,7 @@ bool matchExactPattern( // Match the first 'length' characters of string 'input' and prefix pattern. bool matchPrefixPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() >= length && std::memcmp(input.data(), pattern.data(), length) == 0; @@ -407,7 +407,7 @@ bool matchPrefixPattern( // Match the last 'length' characters of string 'input' and suffix pattern. 
bool matchSuffixPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() >= length && std::memcmp( @@ -418,21 +418,21 @@ bool matchSuffixPattern( bool matchSubstringPattern( const StringView& input, - const StringView& fixedPattern) { + const std::string& unescapedPattern) { return ( - std::string_view(input).find(std::string_view(fixedPattern)) != + std::string_view(input).find(std::string_view(unescapedPattern)) != std::string::npos); } template class OptimizedLike final : public VectorFunction { public: - OptimizedLike(StringView pattern, vector_size_t reducedPatternLength) + OptimizedLike(std::string pattern, vector_size_t reducedPatternLength) : pattern_{pattern}, reducedPatternLength_{reducedPatternLength} {} static bool match( const StringView& input, - const StringView& pattern, + const std::string& pattern, vector_size_t reducedPatternLength) { switch (P) { case PatternKind::kExactlyN: @@ -483,7 +483,7 @@ class OptimizedLike final : public VectorFunction { } private: - StringView pattern_; + std::string pattern_; vector_size_t reducedPatternLength_; }; @@ -592,34 +592,33 @@ class LikeGeneric final : public VectorFunction { auto applyRow = [&](const StringView& input, const StringView& pattern, const std::optional& escapeChar) -> bool { - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - vector_size_t reducedLength = patternMetadata.length; - - switch (patternMetadata.patternKind) { - case PatternKind::kExactlyN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kAtLeastN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kFixed: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kPrefix: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kSuffix: - return OptimizedLike::match( - input, pattern, reducedLength); - case 
PatternKind::kSubstring: - return OptimizedLike::match( - input, StringView(patternMetadata.fixedPattern), reducedLength); - default: - return applyWithRegex(input, pattern, escapeChar); - } + PatternMetadata patternMetadata = + determinePatternKind(pattern, escapeChar); + vector_size_t reducedLength = patternMetadata.length; + auto unescapedPattern = patternMetadata.unescapedPattern; + + switch (patternMetadata.patternKind) { + case PatternKind::kExactlyN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kAtLeastN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kFixed: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kPrefix: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kSuffix: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kSubstring: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + default: + return applyWithRegex(input, pattern, escapeChar); } - return applyWithRegex(input, pattern, escapeChar); }; context.ensureWritable(rows, type, localResult); @@ -970,10 +969,62 @@ std::vector> re2ExtractSignatures() { }; } -PatternMetadata determinePatternKind(StringView pattern) { +std::string unescape( + StringView pattern, + vector_size_t start, + vector_size_t end, + std::optional escapeChar) { + if (!escapeChar) { + return std::string(pattern.data(), start, end - start); + } + + std::ostringstream str; + bool isPreviousEscapeChar = false; + for (auto index = start; index < end; index++) { + auto current = pattern.data()[index]; + if (isPreviousEscapeChar) { + VELOX_USER_CHECK( + current == escapeChar || current == '_' || current == '%', + "Escape character must be followed by '%', '_' or the escape character itself") + str << current; + isPreviousEscapeChar = false; + } else { + if (current == escapeChar) { + isPreviousEscapeChar 
= true; + continue; + } else { + str << current; + } + } + } + + return str.str(); +} + +std::string unescape(StringView pattern, std::optional escapeChar) { + return unescape(pattern, 0, pattern.size(), escapeChar); +} + +bool PatternStringIterator::hasNext() { + return currentIndex_ < (int32_t)(pattern_.size() - 1); +} + +void PatternStringIterator::next() { + currentIndex_++; + previousState_ = state_; + + auto currentChar = pattern_.data()[currentIndex_]; + state_.isEscaping = !previousState_.isEscaping && currentChar == escapeChar_; + state_.isWildcard = !previousState_.isEscaping && + (currentChar == '_' || currentChar == '%') && currentChar != escapeChar_; +} + +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar) { vector_size_t patternLength = pattern.size(); - vector_size_t i = 0; - // Index of the first % or _ character. + + // Index of the first % or _ character(not escaped). vector_size_t wildcardStart = -1; // Count of wildcard character sequences in pattern. vector_size_t numWildcardSequences = 0; @@ -981,40 +1032,80 @@ PatternMetadata determinePatternKind(StringView pattern) { vector_size_t fixedPatternStart = -1; // Index of the last character in the fixed pattern, used to retrieve the // fixed string for patterns of type kSubstring. - vector_size_t fixedPatternEnd = 0; + vector_size_t fixedPatternEnd = -1; // Total number of % characters. vector_size_t anyCharacterWildcardCount = 0; // Total number of _ characters. vector_size_t singleCharacterWildcardCount = 0; - auto patternStr = pattern.data(); - while (i < patternLength) { - if (patternStr[i] == '%' || patternStr[i] == '_') { - if (wildcardStart == -1) { - wildcardStart = i; + PatternStringIterator iterator{pattern, escapeChar}; + bool fallbackToGeneric = false; + + // Return false if failed to handle the new fixed pattern(need to fallback), + // true otherwise. + auto handleNewFixedPattern = [&]() -> bool { + // Record the first fixed pattern start. 
+ if (fixedPatternStart == -1) { + fixedPatternStart = iterator.currentIndex(); + } else { + // This is not the first fixed pattern, not supported, so fallback. + if (iterator.previousState().isWildcard) { + fallbackToGeneric = true; + return false; } - numWildcardSequences++; - // Look till the last contiguous wildcard character, starting from this - // index, is found, or the end of pattern is reached. - while (i < patternLength && - (patternStr[i] == '%' || patternStr[i] == '_')) { - singleCharacterWildcardCount += (patternStr[i] == '_'); - anyCharacterWildcardCount += (patternStr[i] == '%'); - i++; + } + + return true; + }; + + while (iterator.hasNext()) { + iterator.next(); + auto current = iterator.current(); + + if (iterator.previousState().isEscaping) { + // The char follows escapeChar can only be one of (%, _, escapeChar). + if (current != '%' && current != '_' && current != escapeChar) { + fallbackToGeneric = true; + break; + } + + if (!handleNewFixedPattern()) { + break; } } else { - // Ensure that pattern has a single fixed pattern. - if (fixedPatternStart != -1) { - return PatternMetadata{PatternKind::kGeneric, 0}; + // Escaping char, continue. + if (iterator.state().isEscaping) { + continue; } - // Look till the end of fixed pattern, starting from this index, is found, - // or the end of pattern is reached. - fixedPatternStart = i; - while (i < patternLength && - (patternStr[i] != '%' && patternStr[i] != '_')) { - i++; + + if (iterator.state().isWildcard) { + if (wildcardStart == -1) { + wildcardStart = iterator.currentIndex(); + } + + singleCharacterWildcardCount += (iterator.current() == '_'); + anyCharacterWildcardCount += (iterator.current() == '%'); + numWildcardSequences += (!iterator.previousState().isWildcard ? 1 : 0); + + // Mark the end of the fixed pattern. 
+ if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; + } + } else { + if (!handleNewFixedPattern()) { + break; + } } - fixedPatternEnd = i - 1; + } + } + + if (fallbackToGeneric) { + return PatternMetadata{PatternKind::kGeneric, 0}; + } else { + // The pattern end may not been marked if there is no wildcard char after + // pattern start, so we mark it here. + if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; } } @@ -1031,8 +1122,13 @@ PatternMetadata determinePatternKind(StringView pattern) { // At this point pattern contains exactly one fixed pattern. // Pattern contains no wildcard characters (is a fixed pattern). if (wildcardStart == -1) { - return PatternMetadata{PatternKind::kFixed, patternLength}; + auto unescapedPattern = unescape(pattern, 0, patternLength, escapeChar); + return PatternMetadata{ + PatternKind::kFixed, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } + // Pattern is generic if it has '_' wildcard characters and a fixed pattern. if (singleCharacterWildcardCount) { return PatternMetadata{PatternKind::kGeneric, 0}; @@ -1040,19 +1136,28 @@ PatternMetadata determinePatternKind(StringView pattern) { // Classify pattern as prefix, fixed center, or suffix pattern based on the // position and count of the wildcard character sequence and fixed pattern. if (fixedPatternStart < wildcardStart) { - return PatternMetadata{PatternKind::kPrefix, wildcardStart}; + auto unescapedPattern = unescape(pattern, 0, wildcardStart, escapeChar); + return PatternMetadata{ + PatternKind::kPrefix, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } // if numWildcardSequences > 1, then fixed pattern must be in between them. 
if (numWildcardSequences == 2) { + auto unescapedPattern = + unescape(pattern, fixedPatternStart, fixedPatternEnd + 1, escapeChar); return PatternMetadata{ PatternKind::kSubstring, - 0, - std::string( - pattern.data() + fixedPatternStart, - fixedPatternEnd + 1 - fixedPatternStart)}; + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } + auto unescapedPattern = + unescape(pattern, fixedPatternStart, patternLength, escapeChar); + return PatternMetadata{ - PatternKind::kSuffix, patternLength - fixedPatternStart}; + PatternKind::kSuffix, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } std::shared_ptr makeLike( @@ -1095,34 +1200,33 @@ std::shared_ptr makeLike( } auto pattern = constantPattern->as>()->valueAt(0); - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - PatternKind patternKind = patternMetadata.patternKind; - vector_size_t reducedLength = patternMetadata.length; - - switch (patternKind) { - case PatternKind::kExactlyN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kAtLeastN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kFixed: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kPrefix: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kSuffix: - return std::make_shared>( - pattern, reducedLength); - default: - - return std::make_shared(pattern, escapeChar); - } + PatternMetadata patternMetadata = determinePatternKind(pattern, escapeChar); + PatternKind patternKind = patternMetadata.patternKind; + vector_size_t reducedLength = patternMetadata.length; + auto unescapedPattern = patternMetadata.unescapedPattern; + + switch (patternKind) { + case PatternKind::kExactlyN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kAtLeastN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kFixed: + return std::make_shared>( + unescapedPattern, 
reducedLength); + case PatternKind::kPrefix: + return std::make_shared>( + unescapedPattern, reducedLength); + case PatternKind::kSuffix: + return std::make_shared>( + unescapedPattern, reducedLength); + case PatternKind::kSubstring: + return std::make_shared>( + unescapedPattern, reducedLength); + default: + return std::make_shared(pattern, escapeChar); } - - return std::make_shared(pattern, escapeChar); } std::vector> likeSignatures() { diff --git a/velox/functions/lib/Re2Functions.h b/velox/functions/lib/Re2Functions.h index e7d2adcacfd3..fd9bd93797f8 100644 --- a/velox/functions/lib/Re2Functions.h +++ b/velox/functions/lib/Re2Functions.h @@ -48,12 +48,14 @@ enum class PatternKind { struct PatternMetadata { PatternKind patternKind; - // Contains the length of the fixed pattern for patterns of kind kFixed, - // kPrefix, and kSuffix. Contains the count of wildcard character '_' for - // patterns of kind kExactlyN and kAtLeastN. Contains 0 otherwise. + // Contains the length of the unescaped fixed pattern for patterns of kind + // kFixed, kPrefix, kSuffix and kSubstring. Contains the count of wildcard + // character '_' for patterns of kind kExactlyN and kAtLeastN. Contains 0 + // otherwise. vector_size_t length; - // Contains the fixed pattern in patterns of kind kSubstring. - std::string fixedPattern = ""; + // Contains the unescaped fixed pattern in patterns of kind kFixed, kPrefix, + // kSuffix and kSubstring. + std::string unescapedPattern = ""; }; /// The functions in this file use RE2 as the regex engine. RE2 is fast, but @@ -114,7 +116,68 @@ std::vector> re2ExtractSignatures(); /// prefix, and suffix patterns. Return the pair {pattern kind, number of '_' /// characters} for patterns with wildcard characters only. Return /// {kGenericPattern, 0} for generic patterns). 
-PatternMetadata determinePatternKind(StringView pattern); +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar); + +/// Return the unescaped string for the specified string range, if escape char +/// is not specified just return the corresponding substring. +std::string unescape( + StringView pattern, + vector_size_t start, + vector_size_t end, + std::optional escapeChar); +std::string unescape(StringView pattern, std::optional escapeChar); + +/// An Iterator that provides methods(hasNext, next) to iterate through a +/// pattern string. +class PatternStringIterator { + public: + struct State { + // Is current char the escape char? + // NOTE: If escape char is set as '\', for pattern '\\', the first '\' is + // an escaping char, the second is not, it is just a literal '\' + bool isEscaping = false; + // Is current char the wildcard char? + // NOTE: If escape char is set as '\', for pattern '\%%', the first '%' is + // not a wildcard, just a literal '%', the second '%' is a wildcard. + bool isWildcard = false; + }; + + explicit PatternStringIterator( + StringView pattern, + std::optional escapeChar) + : pattern_(pattern), escapeChar_(escapeChar) {} + + bool hasNext(); + void next(); + + char currentIndex() { + return currentIndex_; + } + + char current() { + return pattern_.data()[currentIndex_]; + } + + State state() { + return state_; + } + + State previousState() { + return previousState_; + } + + private: + StringView pattern_; + std::optional escapeChar_; + + int32_t currentIndex_ = -1; + // State of current char. + State state_; + // State of previous char. 
+ State previousState_; +}; std::shared_ptr makeLike( const std::string& name, diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 57376d3f44e6..accc491bc29d 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -460,15 +460,33 @@ TEST_F(Re2FunctionsTest, likePattern) { testLike("abc", "MEDIUM POLISHED%", false); } +TEST_F(Re2FunctionsTest, unescape) { + EXPECT_EQ("%", unescape(R"(\%)", std::make_optional('\\'))); + EXPECT_EQ("%%%", unescape(R"(\%\%\%)", std::make_optional('\\'))); + EXPECT_EQ("a%b%c%d", unescape(R"(a\%b\%c\%d)", std::make_optional('\\'))); + EXPECT_EQ("a_b_c", unescape(R"(a\_b\_c)", std::make_optional('\\'))); + EXPECT_EQ("a_b_c", unescape(R"(%%a\_b\_c)", 2, 9, std::make_optional('\\'))); + + EXPECT_EQ("%%%", unescape(R"(%%%%%%)", std::make_optional('%'))); +} + TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { - auto testPattern = [&](StringView pattern, - PatternKind patternKind, - vector_size_t length, - StringView fixedPattern = "") { - PatternMetadata patternMetadata = determinePatternKind(pattern); + auto testPattern = + [&](StringView pattern, PatternKind patternKind, vector_size_t length) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); + EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, length); + }; + + auto testPatternString = [&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); EXPECT_EQ(patternMetadata.patternKind, patternKind); - EXPECT_EQ(patternMetadata.length, length); - EXPECT_EQ(patternMetadata.fixedPattern, fixedPattern); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); + EXPECT_EQ(patternMetadata.unescapedPattern, fixedPattern); }; testPattern("_", PatternKind::kExactlyN, 1); @@ -477,12 +495,14 @@ 
TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%", PatternKind::kAtLeastN, 0); testPattern("__%%__", PatternKind::kAtLeastN, 4); testPattern("%_%%", PatternKind::kAtLeastN, 1); + testPattern("%%%%%%%%%%%%", PatternKind::kAtLeastN, 0); - testPattern("presto", PatternKind::kFixed, 6); - testPattern("hello", PatternKind::kFixed, 5); - testPattern("a", PatternKind::kFixed, 1); - testPattern("helloPrestoWorld", PatternKind::kFixed, 16); - testPattern("aBcD", PatternKind::kFixed, 4); + testPatternString("presto", PatternKind::kFixed, "presto"); + testPatternString("hello", PatternKind::kFixed, "hello"); + testPatternString("a", PatternKind::kFixed, "a"); + testPatternString( + "helloPrestoWorld", PatternKind::kFixed, "helloPrestoWorld"); + testPatternString("aBcD", PatternKind::kFixed, "aBcD"); testPattern("presto%", PatternKind::kPrefix, 6); testPattern("hello%%", PatternKind::kPrefix, 5); @@ -496,11 +516,11 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%helloPrestoWorld", PatternKind::kSuffix, 16); testPattern("%aBcD", PatternKind::kSuffix, 4); - testPattern("%presto%%", PatternKind::kSubstring, 0, "presto"); - testPattern("%%hello%", PatternKind::kSubstring, 0, "hello"); - testPattern("%%%aAb\n%", PatternKind::kSubstring, 0, "aAb\n"); - testPattern( - "%helloPrestoWorld%%%", PatternKind::kSubstring, 0, "helloPrestoWorld"); + testPatternString("%presto%%", PatternKind::kSubstring, "presto"); + testPatternString("%%hello%", PatternKind::kSubstring, "hello"); + testPatternString("%%%aAb\n%", PatternKind::kSubstring, "aAb\n"); + testPatternString( + "%helloPrestoWorld%%%", PatternKind::kSubstring, "helloPrestoWorld"); testPattern("_b%%__", PatternKind::kGeneric, 0); testPattern("%_%p", PatternKind::kGeneric, 0); @@ -515,6 +535,36 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("_aBcD", PatternKind::kGeneric, 0); } +TEST_F(Re2FunctionsTest, likeDeterminePatternKindWithEscapeChar) { + auto testPattern = 
[&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = determinePatternKind(pattern, '\\'); + EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); + EXPECT_EQ(patternMetadata.unescapedPattern, fixedPattern); + }; + + testPattern(R"(\_)", PatternKind::kFixed, "_"); + testPattern(R"(\_\_\_\_)", PatternKind::kFixed, "____"); + testPattern(R"(a\_\_b\_\_c)", PatternKind::kFixed, "a__b__c"); + + testPattern(R"(\%)", PatternKind::kFixed, "%"); + testPattern(R"(\%\%\%)", PatternKind::kFixed, "%%%"); + testPattern(R"(a\%b\%c\%d)", PatternKind::kFixed, "a%b%c%d"); + + testPattern(R"(\_\_%%)", PatternKind::kPrefix, "__"); + testPattern(R"(a\_b\_c%%)", PatternKind::kPrefix, "a_b_c"); + + testPattern(R"(%%\_\_)", PatternKind::kSuffix, "__"); + testPattern(R"(%%a\_b\_c)", PatternKind::kSuffix, "a_b_c"); + testPattern(R"(%\_\%)", PatternKind::kSuffix, "_%"); + + testPattern(R"(%\_%%)", PatternKind::kSubstring, "_"); + testPattern(R"(%\_\%%%)", PatternKind::kSubstring, "_%"); + testPattern(R"(%\_ab\%%%)", PatternKind::kSubstring, "_ab%"); +} + TEST_F(Re2FunctionsTest, likePatternWildcard) { testLike("", "", true); testLike("", "%", true); @@ -573,6 +623,15 @@ TEST_F(Re2FunctionsTest, likePatternFixed) { testLike("\nabcd\n", "\nabc\nd\n", false); testLike("\nab\tcd\b", "\nabcd\b", false); + // Test literal '_' & '%' in pattern. + testLike("a", R"(\_)", '\\', false); + testLike("_b", R"(\_b)", '\\', true); + testLike("abc_d", R"(abc\_d)", '\\', true); + + testLike("a", R"(\%)", '\\', false); + testLike("abc%d", R"(abc\%d)", '\\', true); + testLike("abc%d", R"(a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input, true); } @@ -605,6 +664,15 @@ TEST_F(Re2FunctionsTest, likePatternPrefix) { testLike("\nabc\nde\n", "ab\nc%", false); testLike("\nabc\nde\n", "abc%", false); + // Test literal '_' & '%' in pattern. 
+ testLike("_", R"(\_%)", '\\', true); + testLike("_bcd", R"(\_b%)", '\\', true); + testLike("abc_defg", R"(abc\_d%)", '\\', true); + + testLike("%ab", R"(\%%)", '\\', true); + testLike("abc%defg", R"(abc\%d%)", '\\', true); + testLike("abc%defg", R"(a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input + generateString(kAnyWildcardCharacter), true); } @@ -637,6 +705,15 @@ TEST_F(Re2FunctionsTest, likePatternSuffix) { testLike("\nabcde\n", "%d\n", false); testLike("\nabcde\n", "%e_\n", false); + // Test literal '_' & '%' in pattern. + testLike("_", R"(%\_)", '\\', true); + testLike("cd_b", R"(%\_b)", '\\', true); + testLike("efgabc_d", R"(%abc\_d)", '\\', true); + + testLike("ab%", R"(%\%)", '\\', true); + testLike("efgabc%d", R"(%abc\%d)", '\\', true); + testLike("abc%defg", R"(%a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike(input, generateString(kAnyWildcardCharacter) + input, true); } @@ -668,6 +745,14 @@ TEST_F(Re2FunctionsTest, likeSubstringPattern) { testLike("\nabcde\n", "%%d\n%", false); testLike("\nabcde\n", "%%e_\n%%", false); + // Test literal '_' & '%' in pattern. + testLike("cd_be", R"(%\_b%)", '\\', true); + testLike("efgabc_dhi", R"(%abc\_d%)", '\\', true); + + testLike("ab%cd", R"(%\%%)", '\\', true); + testLike("efgabc%dhi", R"(%abc\%d%)", '\\', true); + testLike("abc%defg", R"(%a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike( input, @@ -1045,12 +1130,13 @@ TEST_F(Re2FunctionsTest, tryException) { // Make sure we do not compile more than kMaxCompiledRegexes. 
TEST_F(Re2FunctionsTest, likeRegexLimit) { - VectorPtr pattern = makeFlatVector(26); - VectorPtr input = makeFlatVector(26); + int count = 26; + VectorPtr pattern = makeFlatVector(count); + VectorPtr input = makeFlatVector(count); VectorPtr result; auto flatInput = input->asFlatVector(); - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { flatInput->set(i, ""); } @@ -1077,14 +1163,14 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { auto verifyNoRegexCompilationForPattern = [&](PatternKind patternKind) { // Over 20 all optimized, will pass. - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { std::string patternAtIdx = getPatternAtIdx(patternKind, i); flatPattern->set(i, StringView(patternAtIdx)); } result = evaluate("like(c0 , c1)", makeRowVector({input, pattern})); // Pattern '%%%', of type kAtleastN, matches with empty input. assertEqualVectors( - makeConstant((patternKind == PatternKind::kAtLeastN), 26), result); + makeConstant((patternKind == PatternKind::kAtLeastN), count), result); }; // Infer regex compilation does not happen for optimized patterns by verifying @@ -1150,5 +1236,56 @@ TEST_F(Re2FunctionsTest, invalidEscapeChar) { } } +TEST_F(Re2FunctionsTest, patternStringIteratorHasNext) { + std::string pattern = "_"; + PatternStringIterator iterator{StringView{pattern}, '#'}; + + EXPECT_EQ(true, iterator.hasNext()); + + iterator.next(); + EXPECT_EQ(false, iterator.hasNext()); +} + +TEST_F(Re2FunctionsTest, patternStringIteratorIsEscaping) { + std::string pattern = "####"; + PatternStringIterator iterator{StringView{pattern}, '#'}; + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(true, iterator.previousState().isEscaping); + EXPECT_EQ(false, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(false, iterator.previousState().isEscaping); + EXPECT_EQ(true, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(true, iterator.previousState().isEscaping); + 
EXPECT_EQ(false, iterator.state().isEscaping); +} + +TEST_F(Re2FunctionsTest, patternStringIteratorIsWildcard) { + std::string pattern = "%%%%"; + PatternStringIterator iterator{StringView{pattern}, '%'}; + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(false, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(false, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); +} + } // namespace } // namespace facebook::velox::functions