diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index 59bfd8451105..7299f86c72ae 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -389,7 +389,7 @@ class Re2SearchAndExtract final : public VectorFunction { // Match string 'input' with a fixed pattern (with no wildcard characters). bool matchExactPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() == pattern.size() && std::memcmp(input.data(), pattern.data(), length) == 0; @@ -398,7 +398,7 @@ bool matchExactPattern( // Match the first 'length' characters of string 'input' and prefix pattern. bool matchPrefixPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() >= length && std::memcmp(input.data(), pattern.data(), length) == 0; @@ -407,7 +407,7 @@ bool matchPrefixPattern( // Match the last 'length' characters of string 'input' and suffix pattern. bool matchSuffixPattern( StringView input, - StringView pattern, + const std::string& pattern, vector_size_t length) { return input.size() >= length && std::memcmp( @@ -418,21 +418,21 @@ bool matchSuffixPattern( bool matchSubstringPattern( const StringView& input, - const StringView& fixedPattern) { + const std::string& unescapedPattern) { return ( - std::string_view(input).find(std::string_view(fixedPattern)) != + std::string_view(input).find(std::string_view(unescapedPattern)) != std::string::npos); } template class OptimizedLike final : public VectorFunction { public: - OptimizedLike(StringView pattern, vector_size_t reducedPatternLength) + OptimizedLike(std::string pattern, vector_size_t reducedPatternLength) : pattern_{pattern}, reducedPatternLength_{reducedPatternLength} {} static bool match( const StringView& input, - const StringView& pattern, + const std::string& pattern, vector_size_t reducedPatternLength) { switch (P) { case PatternKind::kExactlyN: @@ -483,7 +483,7 @@ class OptimizedLike final : public VectorFunction { } private: - StringView pattern_; + std::string pattern_; vector_size_t reducedPatternLength_; }; @@ -592,34 +592,33 @@ class LikeGeneric final : public VectorFunction { auto applyRow = [&](const StringView& input, const StringView& pattern, const std::optional& escapeChar) -> bool { - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - vector_size_t reducedLength = patternMetadata.length; - - switch (patternMetadata.patternKind) { - case PatternKind::kExactlyN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kAtLeastN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kFixed: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kPrefix: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kSuffix: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kSubstring: - return OptimizedLike::match( - input, StringView(patternMetadata.fixedPattern), reducedLength); - default: - return applyWithRegex(input, pattern, escapeChar); - } + PatternMetadata patternMetadata = + determinePatternKind(pattern, escapeChar); + vector_size_t reducedLength = patternMetadata.length; + auto unescapedPattern = patternMetadata.unescapedPattern; + + switch (patternMetadata.patternKind) { + case PatternKind::kExactlyN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kAtLeastN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kFixed: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kPrefix: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kSuffix: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + case PatternKind::kSubstring: + return OptimizedLike::match( + input, unescapedPattern, reducedLength); + default: + return applyWithRegex(input, pattern, escapeChar); } - return applyWithRegex(input, pattern, escapeChar); }; context.ensureWritable(rows, type, localResult); @@ -970,10 +969,62 @@ std::vector> re2ExtractSignatures() { }; } -PatternMetadata determinePatternKind(StringView pattern) { +std::string unescape( + StringView pattern, + vector_size_t start, + vector_size_t end, + std::optional escapeChar) { + if (!escapeChar) { + return std::string(pattern.data(), start, end - start); + } + + std::ostringstream str; + bool isPreviousEscapeChar = false; + for (auto index = start; index < end; index++) { + auto current = pattern.data()[index]; + if (isPreviousEscapeChar) { + VELOX_USER_CHECK( + current == escapeChar || current == '_' || current == '%', + "Escape character must be followed by '%', '_' or the escape character itself") + str << current; + isPreviousEscapeChar = false; + } else { + if (current == escapeChar) { + isPreviousEscapeChar = true; + continue; + } else { + str << current; + } + } + } + + return str.str(); +} + +std::string unescape(StringView pattern, std::optional escapeChar) { + return unescape(pattern, 0, pattern.size(), escapeChar); +} + +bool PatternStringIterator::hasNext() { + return currentIndex_ < (int32_t)(pattern_.size() - 1); +} + +void PatternStringIterator::next() { + currentIndex_++; + previousState_ = state_; + + auto currentChar = pattern_.data()[currentIndex_]; + state_.isEscaping = !previousState_.isEscaping && currentChar == escapeChar_; + state_.isWildcard = !previousState_.isEscaping && + (currentChar == '_' || currentChar == '%') && currentChar != escapeChar_; +} + +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar) { vector_size_t patternLength = pattern.size(); - vector_size_t i = 0; - // Index of the first % or _ character. + + // Index of the first % or _ character(not escaped). vector_size_t wildcardStart = -1; // Count of wildcard character sequences in pattern. vector_size_t numWildcardSequences = 0; @@ -981,40 +1032,80 @@ PatternMetadata determinePatternKind(StringView pattern) { vector_size_t fixedPatternStart = -1; // Index of the last character in the fixed pattern, used to retrieve the // fixed string for patterns of type kSubstring. - vector_size_t fixedPatternEnd = 0; + vector_size_t fixedPatternEnd = -1; // Total number of % characters. vector_size_t anyCharacterWildcardCount = 0; // Total number of _ characters. vector_size_t singleCharacterWildcardCount = 0; - auto patternStr = pattern.data(); - while (i < patternLength) { - if (patternStr[i] == '%' || patternStr[i] == '_') { - if (wildcardStart == -1) { - wildcardStart = i; + PatternStringIterator iterator{pattern, escapeChar}; + bool fallbackToGeneric = false; + + // Return false if failed to handle the new fixed pattern(need to fallback), + // true otherwise. + auto handleNewFixedPattern = [&]() -> bool { + // Record the first fixed pattern start. + if (fixedPatternStart == -1) { + fixedPatternStart = iterator.currentIndex(); + } else { + // This is not the first fixed pattern, not supported, so fallback. + if (iterator.previousState().isWildcard) { + fallbackToGeneric = true; + return false; } - numWildcardSequences++; - // Look till the last contiguous wildcard character, starting from this - // index, is found, or the end of pattern is reached. - while (i < patternLength && - (patternStr[i] == '%' || patternStr[i] == '_')) { - singleCharacterWildcardCount += (patternStr[i] == '_'); - anyCharacterWildcardCount += (patternStr[i] == '%'); - i++; + } + + return true; + }; + + while (iterator.hasNext()) { + iterator.next(); + auto current = iterator.current(); + + if (iterator.previousState().isEscaping) { + // The char follows escapeChar can only be one of (%, _, escapeChar). + if (current != '%' && current != '_' && current != escapeChar) { + fallbackToGeneric = true; + break; + } + + if (!handleNewFixedPattern()) { + break; } } else { - // Ensure that pattern has a single fixed pattern. - if (fixedPatternStart != -1) { - return PatternMetadata{PatternKind::kGeneric, 0}; + // Escaping char, continue. + if (iterator.state().isEscaping) { + continue; } - // Look till the end of fixed pattern, starting from this index, is found, - // or the end of pattern is reached. - fixedPatternStart = i; - while (i < patternLength && - (patternStr[i] != '%' && patternStr[i] != '_')) { - i++; + + if (iterator.state().isWildcard) { + if (wildcardStart == -1) { + wildcardStart = iterator.currentIndex(); + } + + singleCharacterWildcardCount += (iterator.current() == '_'); + anyCharacterWildcardCount += (iterator.current() == '%'); + numWildcardSequences += (!iterator.previousState().isWildcard ? 1 : 0); + + // Mark the end of the fixed pattern. + if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; + } + } else { + if (!handleNewFixedPattern()) { + break; + } } - fixedPatternEnd = i - 1; + } + } + + if (fallbackToGeneric) { + return PatternMetadata{PatternKind::kGeneric, 0}; + } else { + // The pattern end may not been marked if there is no wildcard char after + // pattern start, so we mark it here. + if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; } } @@ -1031,8 +1122,13 @@ PatternMetadata determinePatternKind(StringView pattern) { // At this point pattern contains exactly one fixed pattern. // Pattern contains no wildcard characters (is a fixed pattern). if (wildcardStart == -1) { - return PatternMetadata{PatternKind::kFixed, patternLength}; + auto unescapedPattern = unescape(pattern, 0, patternLength, escapeChar); + return PatternMetadata{ + PatternKind::kFixed, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } + // Pattern is generic if it has '_' wildcard characters and a fixed pattern. if (singleCharacterWildcardCount) { return PatternMetadata{PatternKind::kGeneric, 0}; @@ -1040,19 +1136,28 @@ PatternMetadata determinePatternKind(StringView pattern) { // Classify pattern as prefix, fixed center, or suffix pattern based on the // position and count of the wildcard character sequence and fixed pattern. if (fixedPatternStart < wildcardStart) { - return PatternMetadata{PatternKind::kPrefix, wildcardStart}; + auto unescapedPattern = unescape(pattern, 0, wildcardStart, escapeChar); + return PatternMetadata{ + PatternKind::kPrefix, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } // if numWildcardSequences > 1, then fixed pattern must be in between them. if (numWildcardSequences == 2) { + auto unescapedPattern = + unescape(pattern, fixedPatternStart, fixedPatternEnd + 1, escapeChar); return PatternMetadata{ PatternKind::kSubstring, - 0, - std::string( - pattern.data() + fixedPatternStart, - fixedPatternEnd + 1 - fixedPatternStart)}; + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } + auto unescapedPattern = + unescape(pattern, fixedPatternStart, patternLength, escapeChar); + return PatternMetadata{ - PatternKind::kSuffix, patternLength - fixedPatternStart}; + PatternKind::kSuffix, + (vector_size_t)unescapedPattern.size(), + unescapedPattern}; } std::shared_ptr makeLike( @@ -1095,34 +1200,33 @@ std::shared_ptr makeLike( } auto pattern = constantPattern->as>()->valueAt(0); - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - PatternKind patternKind = patternMetadata.patternKind; - vector_size_t reducedLength = patternMetadata.length; - - switch (patternKind) { - case PatternKind::kExactlyN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kAtLeastN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kFixed: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kPrefix: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kSuffix: - return std::make_shared>( - pattern, reducedLength); - default: - - return std::make_shared(pattern, escapeChar); - } + PatternMetadata patternMetadata = determinePatternKind(pattern, escapeChar); + PatternKind patternKind = patternMetadata.patternKind; + vector_size_t reducedLength = patternMetadata.length; + auto unescapedPattern = patternMetadata.unescapedPattern; + + switch (patternKind) { + case PatternKind::kExactlyN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kAtLeastN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kFixed: + return std::make_shared>( + unescapedPattern, reducedLength); + case PatternKind::kPrefix: + return std::make_shared>( + unescapedPattern, reducedLength); + case PatternKind::kSuffix: + return std::make_shared>( + unescapedPattern, reducedLength); + case PatternKind::kSubstring: + return std::make_shared>( + unescapedPattern, reducedLength); + default: + return std::make_shared(pattern, escapeChar); } - - return std::make_shared(pattern, escapeChar); } std::vector> likeSignatures() { diff --git a/velox/functions/lib/Re2Functions.h b/velox/functions/lib/Re2Functions.h index e7d2adcacfd3..fd9bd93797f8 100644 --- a/velox/functions/lib/Re2Functions.h +++ b/velox/functions/lib/Re2Functions.h @@ -48,12 +48,14 @@ enum class PatternKind { struct PatternMetadata { PatternKind patternKind; - // Contains the length of the fixed pattern for patterns of kind kFixed, - // kPrefix, and kSuffix. Contains the count of wildcard character '_' for - // patterns of kind kExactlyN and kAtLeastN. Contains 0 otherwise. + // Contains the length of the unescaped fixed pattern for patterns of kind + // kFixed, kPrefix, kSuffix and kSubstring. Contains the count of wildcard + // character '_' for patterns of kind kExactlyN and kAtLeastN. Contains 0 + // otherwise. vector_size_t length; - // Contains the fixed pattern in patterns of kind kSubstring. - std::string fixedPattern = ""; + // Contains the unescaped fixed pattern in patterns of kind kFixed, kPrefix, + // kSuffix and kSubstring. + std::string unescapedPattern = ""; }; /// The functions in this file use RE2 as the regex engine. RE2 is fast, but @@ -114,7 +116,68 @@ std::vector> re2ExtractSignatures(); /// prefix, and suffix patterns. Return the pair {pattern kind, number of '_' /// characters} for patterns with wildcard characters only. Return /// {kGenericPattern, 0} for generic patterns). -PatternMetadata determinePatternKind(StringView pattern); +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar); + +/// Return the unescaped string for the specified string range, if escape char +/// is not specified just return the corresponding substring. +std::string unescape( + StringView pattern, + vector_size_t start, + vector_size_t end, + std::optional escapeChar); +std::string unescape(StringView pattern, std::optional escapeChar); + +/// An Iterator that provides methods(hasNext, next) to iterate through a +/// pattern string. +class PatternStringIterator { + public: + struct State { + // Is current char the escape char? + // NOTE: If escape char is set as '\', for pattern '\\', the first '\' is + // an escaping char, the second is not, it is just a literal '\' + bool isEscaping = false; + // Is current char the wildcard char? + // NOTE: If escape char is set as '\', for pattern '\%%', the first '%' is + // not a wildcard, just a literal '%', the second '%' is a wildcard. + bool isWildcard = false; + }; + + explicit PatternStringIterator( + StringView pattern, + std::optional escapeChar) + : pattern_(pattern), escapeChar_(escapeChar) {} + + bool hasNext(); + void next(); + + char currentIndex() { + return currentIndex_; + } + + char current() { + return pattern_.data()[currentIndex_]; + } + + State state() { + return state_; + } + + State previousState() { + return previousState_; + } + + private: + StringView pattern_; + std::optional escapeChar_; + + int32_t currentIndex_ = -1; + // State of current char. + State state_; + // State of previous char. + State previousState_; +}; std::shared_ptr makeLike( const std::string& name, diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 57376d3f44e6..accc491bc29d 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -460,15 +460,33 @@ TEST_F(Re2FunctionsTest, likePattern) { testLike("abc", "MEDIUM POLISHED%", false); } +TEST_F(Re2FunctionsTest, unescape) { + EXPECT_EQ("%", unescape(R"(\%)", std::make_optional('\\'))); + EXPECT_EQ("%%%", unescape(R"(\%\%\%)", std::make_optional('\\'))); + EXPECT_EQ("a%b%c%d", unescape(R"(a\%b\%c\%d)", std::make_optional('\\'))); + EXPECT_EQ("a_b_c", unescape(R"(a\_b\_c)", std::make_optional('\\'))); + EXPECT_EQ("a_b_c", unescape(R"(%%a\_b\_c)", 2, 9, std::make_optional('\\'))); + + EXPECT_EQ("%%%", unescape(R"(%%%%%%)", std::make_optional('%'))); +} + TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { - auto testPattern = [&](StringView pattern, - PatternKind patternKind, - vector_size_t length, - StringView fixedPattern = "") { - PatternMetadata patternMetadata = determinePatternKind(pattern); + auto testPattern = + [&](StringView pattern, PatternKind patternKind, vector_size_t length) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); + EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, length); + }; + + auto testPatternString = [&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); EXPECT_EQ(patternMetadata.patternKind, patternKind); - EXPECT_EQ(patternMetadata.length, length); - EXPECT_EQ(patternMetadata.fixedPattern, fixedPattern); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); + EXPECT_EQ(patternMetadata.unescapedPattern, fixedPattern); }; testPattern("_", PatternKind::kExactlyN, 1); @@ -477,12 +495,14 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%", PatternKind::kAtLeastN, 0); testPattern("__%%__", PatternKind::kAtLeastN, 4); testPattern("%_%%", PatternKind::kAtLeastN, 1); + testPattern("%%%%%%%%%%%%", PatternKind::kAtLeastN, 0); - testPattern("presto", PatternKind::kFixed, 6); - testPattern("hello", PatternKind::kFixed, 5); - testPattern("a", PatternKind::kFixed, 1); - testPattern("helloPrestoWorld", PatternKind::kFixed, 16); - testPattern("aBcD", PatternKind::kFixed, 4); + testPatternString("presto", PatternKind::kFixed, "presto"); + testPatternString("hello", PatternKind::kFixed, "hello"); + testPatternString("a", PatternKind::kFixed, "a"); + testPatternString( + "helloPrestoWorld", PatternKind::kFixed, "helloPrestoWorld"); + testPatternString("aBcD", PatternKind::kFixed, "aBcD"); testPattern("presto%", PatternKind::kPrefix, 6); testPattern("hello%%", PatternKind::kPrefix, 5); @@ -496,11 +516,11 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%helloPrestoWorld", PatternKind::kSuffix, 16); testPattern("%aBcD", PatternKind::kSuffix, 4); - testPattern("%presto%%", PatternKind::kSubstring, 0, "presto"); - testPattern("%%hello%", PatternKind::kSubstring, 0, "hello"); - testPattern("%%%aAb\n%", PatternKind::kSubstring, 0, "aAb\n"); - testPattern( - "%helloPrestoWorld%%%", PatternKind::kSubstring, 0, "helloPrestoWorld"); + testPatternString("%presto%%", PatternKind::kSubstring, "presto"); + testPatternString("%%hello%", PatternKind::kSubstring, "hello"); + testPatternString("%%%aAb\n%", PatternKind::kSubstring, "aAb\n"); + testPatternString( + "%helloPrestoWorld%%%", PatternKind::kSubstring, "helloPrestoWorld"); testPattern("_b%%__", PatternKind::kGeneric, 0); testPattern("%_%p", PatternKind::kGeneric, 0); @@ -515,6 +535,36 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("_aBcD", PatternKind::kGeneric, 0); } +TEST_F(Re2FunctionsTest, likeDeterminePatternKindWithEscapeChar) { + auto testPattern = [&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = determinePatternKind(pattern, '\\'); + EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); + EXPECT_EQ(patternMetadata.unescapedPattern, fixedPattern); + }; + + testPattern(R"(\_)", PatternKind::kFixed, "_"); + testPattern(R"(\_\_\_\_)", PatternKind::kFixed, "____"); + testPattern(R"(a\_\_b\_\_c)", PatternKind::kFixed, "a__b__c"); + + testPattern(R"(\%)", PatternKind::kFixed, "%"); + testPattern(R"(\%\%\%)", PatternKind::kFixed, "%%%"); + testPattern(R"(a\%b\%c\%d)", PatternKind::kFixed, "a%b%c%d"); + + testPattern(R"(\_\_%%)", PatternKind::kPrefix, "__"); + testPattern(R"(a\_b\_c%%)", PatternKind::kPrefix, "a_b_c"); + + testPattern(R"(%%\_\_)", PatternKind::kSuffix, "__"); + testPattern(R"(%%a\_b\_c)", PatternKind::kSuffix, "a_b_c"); + testPattern(R"(%\_\%)", PatternKind::kSuffix, "_%"); + + testPattern(R"(%\_%%)", PatternKind::kSubstring, "_"); + testPattern(R"(%\_\%%%)", PatternKind::kSubstring, "_%"); + testPattern(R"(%\_ab\%%%)", PatternKind::kSubstring, "_ab%"); +} + TEST_F(Re2FunctionsTest, likePatternWildcard) { testLike("", "", true); testLike("", "%", true); @@ -573,6 +623,15 @@ TEST_F(Re2FunctionsTest, likePatternFixed) { testLike("\nabcd\n", "\nabc\nd\n", false); testLike("\nab\tcd\b", "\nabcd\b", false); + // Test literal '_' & '%' in pattern. + testLike("a", R"(\_)", '\\', false); + testLike("_b", R"(\_b)", '\\', true); + testLike("abc_d", R"(abc\_d)", '\\', true); + + testLike("a", R"(\%)", '\\', false); + testLike("abc%d", R"(abc\%d)", '\\', true); + testLike("abc%d", R"(a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input, true); } @@ -605,6 +664,15 @@ TEST_F(Re2FunctionsTest, likePatternPrefix) { testLike("\nabc\nde\n", "ab\nc%", false); testLike("\nabc\nde\n", "abc%", false); + // Test literal '_' & '%' in pattern. + testLike("_", R"(\_%)", '\\', true); + testLike("_bcd", R"(\_b%)", '\\', true); + testLike("abc_defg", R"(abc\_d%)", '\\', true); + + testLike("%ab", R"(\%%)", '\\', true); + testLike("abc%defg", R"(abc\%d%)", '\\', true); + testLike("abc%defg", R"(a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input + generateString(kAnyWildcardCharacter), true); } @@ -637,6 +705,15 @@ TEST_F(Re2FunctionsTest, likePatternSuffix) { testLike("\nabcde\n", "%d\n", false); testLike("\nabcde\n", "%e_\n", false); + // Test literal '_' & '%' in pattern. + testLike("_", R"(%\_)", '\\', true); + testLike("cd_b", R"(%\_b)", '\\', true); + testLike("efgabc_d", R"(%abc\_d)", '\\', true); + + testLike("ab%", R"(%\%)", '\\', true); + testLike("efgabc%d", R"(%abc\%d)", '\\', true); + testLike("abc%defg", R"(%a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike(input, generateString(kAnyWildcardCharacter) + input, true); } @@ -668,6 +745,14 @@ TEST_F(Re2FunctionsTest, likeSubstringPattern) { testLike("\nabcde\n", "%%d\n%", false); testLike("\nabcde\n", "%%e_\n%%", false); + // Test literal '_' & '%' in pattern. + testLike("cd_be", R"(%\_b%)", '\\', true); + testLike("efgabc_dhi", R"(%abc\_d%)", '\\', true); + + testLike("ab%cd", R"(%\%%)", '\\', true); + testLike("efgabc%dhi", R"(%abc\%d%)", '\\', true); + testLike("abc%defg", R"(%a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike( input, @@ -1045,12 +1130,13 @@ TEST_F(Re2FunctionsTest, tryException) { // Make sure we do not compile more than kMaxCompiledRegexes. TEST_F(Re2FunctionsTest, likeRegexLimit) { - VectorPtr pattern = makeFlatVector(26); - VectorPtr input = makeFlatVector(26); + int count = 26; + VectorPtr pattern = makeFlatVector(count); + VectorPtr input = makeFlatVector(count); VectorPtr result; auto flatInput = input->asFlatVector(); - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { flatInput->set(i, ""); } @@ -1077,14 +1163,14 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { auto verifyNoRegexCompilationForPattern = [&](PatternKind patternKind) { // Over 20 all optimized, will pass. - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { std::string patternAtIdx = getPatternAtIdx(patternKind, i); flatPattern->set(i, StringView(patternAtIdx)); } result = evaluate("like(c0 , c1)", makeRowVector({input, pattern})); // Pattern '%%%', of type kAtleastN, matches with empty input. assertEqualVectors( - makeConstant((patternKind == PatternKind::kAtLeastN), 26), result); + makeConstant((patternKind == PatternKind::kAtLeastN), count), result); }; // Infer regex compilation does not happen for optimized patterns by verifying @@ -1150,5 +1236,56 @@ TEST_F(Re2FunctionsTest, invalidEscapeChar) { } } +TEST_F(Re2FunctionsTest, patternStringIteratorHasNext) { + std::string pattern = "_"; + PatternStringIterator iterator{StringView{pattern}, '#'}; + + EXPECT_EQ(true, iterator.hasNext()); + + iterator.next(); + EXPECT_EQ(false, iterator.hasNext()); +} + +TEST_F(Re2FunctionsTest, patternStringIteratorIsEscaping) { + std::string pattern = "####"; + PatternStringIterator iterator{StringView{pattern}, '#'}; + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(true, iterator.previousState().isEscaping); + EXPECT_EQ(false, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(false, iterator.previousState().isEscaping); + EXPECT_EQ(true, iterator.state().isEscaping); + + iterator.next(); + EXPECT_EQ(true, iterator.previousState().isEscaping); + EXPECT_EQ(false, iterator.state().isEscaping); +} + +TEST_F(Re2FunctionsTest, patternStringIteratorIsWildcard) { + std::string pattern = "%%%%"; + PatternStringIterator iterator{StringView{pattern}, '%'}; + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(false, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(true, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); + + iterator.next(); + EXPECT_EQ(false, iterator.state().isEscaping); + EXPECT_EQ(false, iterator.state().isWildcard); +} + } // namespace } // namespace facebook::velox::functions