diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index 39c3f79e8b7a..41bcbb700505 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -1024,15 +1024,13 @@ std::string unescape( class PatternStringIterator { public: PatternStringIterator(StringView pattern, std::optional escapeChar) - : pattern_(pattern), - escapeChar_(escapeChar), - lastIndex_{pattern_.size() - 1} {} + : pattern_(pattern), escapeChar_(escapeChar) {} // Advance the cursor to next char, escape char is automatically handled. // Return true if the cursor is advanced successfully, false otherwise(reached // the end of the pattern string). bool next() { - if (currentIndex_ == lastIndex_) { + if (nextStart_ == pattern_.size()) { return false; } @@ -1040,19 +1038,18 @@ class PatternStringIterator { (charKind_ == CharKind::kSingleCharWildcard || charKind_ == CharKind::kAnyCharsWildcard); - currentIndex_++; - auto currentChar = current(); + currentStart_ = nextStart_; + auto currentChar = charAt(currentStart_); if (currentChar == escapeChar_) { // Escape char should be followed by another char. VELOX_USER_CHECK_LT( - currentIndex_, - lastIndex_, + currentStart_ + 1, + pattern_.size(), "Escape character must be followed by '%', '_' or the escape character itself: {}, escape {}", pattern_, escapeChar_.value()) - currentIndex_++; - currentChar = current(); + currentChar = charAt(currentStart_ + 1); // The char follows escapeChar can only be one of (%, _, escapeChar). if (currentChar == escapeChar_ || currentChar == '_' || currentChar == '%') { @@ -1063,20 +1060,25 @@ class PatternStringIterator { pattern_, escapeChar_.value()) } - } else if (currentChar == '_') { - charKind_ = CharKind::kSingleCharWildcard; - } else if (currentChar == '%') { - charKind_ = CharKind::kAnyCharsWildcard; + // One escape char plus the current char. + nextStart_ = currentStart_ + 2; } else { - charKind_ = CharKind::kNormal; + if (currentChar == '_') { + charKind_ = CharKind::kSingleCharWildcard; + } else if (currentChar == '%') { + charKind_ = CharKind::kAnyCharsWildcard; + } else { + charKind_ = CharKind::kNormal; + } + nextStart_ = currentStart_ + 1; } return true; } - // Current index of the cursor. - char currentIndex() const { - return currentIndex_; + // Start index of the current character. + size_t currentStart() const { + return currentStart_; } bool isAnyCharsWildcard() const { @@ -1111,15 +1113,16 @@ class PatternStringIterator { }; // Char at current cursor. - char current() const { - return pattern_.data()[currentIndex_]; + char charAt(size_t index) const { + VELOX_DCHECK(index >= 0 && index < pattern_.size()) + return pattern_.data()[index]; } const StringView pattern_; const std::optional escapeChar_; - const size_t lastIndex_; - int32_t currentIndex_{-1}; + size_t currentStart_{0}; + size_t nextStart_{0}; CharKind charKind_{CharKind::kNormal}; bool isPreviousWildcard_{false}; }; @@ -1127,7 +1130,7 @@ class PatternStringIterator { PatternMetadata determinePatternKind( StringView pattern, std::optional escapeChar) { - int32_t patternLength = pattern.size(); + const size_t patternLength = pattern.size(); // Index of the first % or _ character(not escaped). int32_t wildcardStart = -1; @@ -1148,9 +1151,10 @@ PatternMetadata determinePatternKind( // Iterate through the pattern string to collect the stats for the simple // patterns that we can optimize. while (iterator.next()) { + const size_t currentStart = iterator.currentStart(); if (iterator.isWildcard()) { if (wildcardStart == -1) { - wildcardStart = iterator.currentIndex(); + wildcardStart = currentStart; } if (iterator.isSingleCharWildcard()) { @@ -1165,12 +1169,12 @@ PatternMetadata determinePatternKind( // Mark the end of the fixed pattern. if (fixedPatternStart != -1 && fixedPatternEnd == -1) { - fixedPatternEnd = iterator.currentIndex() - 1; + fixedPatternEnd = currentStart - 1; } } else { // Record the first fixed pattern start. if (fixedPatternStart == -1) { - fixedPatternStart = iterator.currentIndex(); + fixedPatternStart = currentStart; } else { // This is not the first fixed pattern, not supported, so fallback. if (iterator.isPreviousWildcard()) { @@ -1183,7 +1187,7 @@ PatternMetadata determinePatternKind( // The pattern end may not been marked if there is no wildcard char after // pattern start, so we mark it here. if (fixedPatternStart != -1 && fixedPatternEnd == -1) { - fixedPatternEnd = iterator.currentIndex() - 1; + fixedPatternEnd = patternLength - 1; } // At this point pattern has max of one fixed pattern. diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index fc50fa176f97..d85eb6f90905 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -591,6 +591,18 @@ TEST_F(Re2FunctionsTest, likePatternWildcard) { testLike("\nabcde\n", "%bcf%", false); } +TEST_F(Re2FunctionsTest, likePatternEscapingEscapeChar) { + testLike(R"(\)", R"(\\)", '\\', true); + testLike(R"(\abc)", R"(\\%)", '\\', true); + testLike(R"(\abc)", R"(\\abc)", '\\', true); + testLike(R"(abc\abc)", R"(abc\\abc)", '\\', true); + testLike(R"(\abcdef)", R"(\\abc%)", '\\', true); + testLike(R"(\abcdefghijkl)", R"(\\abc%gh%)", '\\', true); + testLike(R"(abc\abc)", R"(%\\%)", '\\', true); + testLike(R"(abcdef\abcdef)", R"(%\\abc%)", '\\', true); + testLike(R"(abcdef\\\abcdef)", R"(%\\\\\\abc%)", '\\', true); +} + TEST_F(Re2FunctionsTest, likePatternFixed) { testLike("", "", true); testLike("abcde", "abcde", true);