From 1779e82dfa3f41fb4c066fa184a4c1a228ee376a Mon Sep 17 00:00:00 2001 From: xumingming Date: Wed, 13 Dec 2023 11:31:26 -0800 Subject: [PATCH] Optimize LIKE with custom escape char (#7730) Summary: Currently we optimize the LIKE operation only if no escape char is specified. This PR adds the ability to apply the optimization even if the user specifies an escape char. We introduce a PatternStringIterator which handles escaping transparently, so existing optimizations (kPrefix, kSuffix, kSubstring, etc.) now work for patterns with an escape char, and future optimizations will take effect for escaped patterns transparently too. The benchmark result before this optimization: ``` ============================================================================ [...]hmarks/ExpressionBenchmarkBuilder.cpp relative time/iter iters/s ============================================================================ like_generic##like_generic 4.14s 241.44m ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- like_prefix##like_prefix 1.20s 833.70m like_prefix##starts_with 2.92ms 342.44 like_substring##like_substring 4.22s 236.77m like_substring##strpos 6.98ms 143.27 like_suffix##like_suffix 3.09s 323.90m like_suffix##ends_with 3.02ms 331.11 ``` After: ``` ============================================================================ [...]hmarks/ExpressionBenchmarkBuilder.cpp relative time/iter iters/s ============================================================================ like_generic##like_generic 3.86s 258.97m ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- like_prefix##like_prefix 4.18ms 239.24 like_prefix##starts_with 2.76ms 362.05 like_substring##like_substring 7.71ms 129.75 like_substring##strpos 6.67ms 149.90 like_suffix##like_suffix 4.20ms 237.85 like_suffix##ends_with 2.90ms 
344.93 ``` In summary: - Speedup of kSubstring is about 500x. - Speedup of kPrefix is about 250x. - Speedup of kSuffix is about 700x. Why is the speedup so huge? There are two reasons: - Re2 is really slow compared to the optimizations we made; even if the input string is short (10 bytes), Re2 is 100x slower than our optimizations. - When the input strings get longer (10 bytes -> 1000 bytes), the performance of our optimizations does not change much, but Re2 becomes 10x slower. We can confirm the speedup is reasonable by comparing our optimizations with the simple scalar functions strpos, starts_with and ends_with: the performance numbers are quite close (see the like##strpos/starts_with/ends_with entries in the benchmark result for more details). Pull Request resolved: https://github.com/facebookincubator/velox/pull/7730 Reviewed By: pedroerp Differential Revision: D52077250 Pulled By: mbasmanova fbshipit-source-id: 39703ddcc7f4f2044460d93866670f730e139120 --- velox/benchmarks/basic/CMakeLists.txt | 9 +- velox/benchmarks/basic/LikeBenchmark.cpp | 96 +++++ ...onsBenchmark.cpp => LikeTpchBenchmark.cpp} | 0 velox/functions/lib/Re2Functions.cpp | 403 +++++++++++++----- velox/functions/lib/Re2Functions.h | 16 +- .../functions/lib/tests/Re2FunctionsTest.cpp | 118 ++++- 6 files changed, 503 insertions(+), 139 deletions(-) create mode 100644 velox/benchmarks/basic/LikeBenchmark.cpp rename velox/benchmarks/basic/{LikeFunctionsBenchmark.cpp => LikeTpchBenchmark.cpp} (100%) diff --git a/velox/benchmarks/basic/CMakeLists.txt b/velox/benchmarks/basic/CMakeLists.txt index 427a2db7ea9a..dc7ba99d7b99 100644 --- a/velox/benchmarks/basic/CMakeLists.txt +++ b/velox/benchmarks/basic/CMakeLists.txt @@ -63,10 +63,15 @@ add_executable(velox_benchmark_basic_preproc Preproc.cpp) target_link_libraries(velox_benchmark_basic_preproc ${velox_benchmark_deps} velox_functions_prestosql velox_vector_test_lib) -add_executable(velox_like_functions_benchmark LikeFunctionsBenchmark.cpp) 
-target_link_libraries(velox_like_functions_benchmark ${velox_benchmark_deps} +add_executable(velox_like_tpch_benchmark LikeTpchBenchmark.cpp) +target_link_libraries(velox_like_tpch_benchmark ${velox_benchmark_deps} velox_functions_lib velox_tpch_gen velox_vector_test_lib) +add_executable(velox_like_benchmark LikeBenchmark.cpp) +target_link_libraries( + velox_like_benchmark ${velox_benchmark_deps} velox_functions_lib + velox_functions_prestosql velox_vector_test_lib) + add_executable(velox_benchmark_basic_vector_fuzzer VectorFuzzer.cpp) target_link_libraries(velox_benchmark_basic_vector_fuzzer ${velox_benchmark_deps} velox_vector_test_lib) diff --git a/velox/benchmarks/basic/LikeBenchmark.cpp b/velox/benchmarks/basic/LikeBenchmark.cpp new file mode 100644 index 000000000000..553090191b1c --- /dev/null +++ b/velox/benchmarks/basic/LikeBenchmark.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "velox/benchmarks/ExpressionBenchmarkBuilder.h" +#include "velox/functions/lib/Re2Functions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +using namespace facebook; +using namespace facebook::velox; +using namespace facebook::velox::functions; +using namespace facebook::velox::functions::test; +using namespace facebook::velox::memory; +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + + exec::registerStatefulVectorFunction("like", likeSignatures(), makeLike); + // Register the scalar functions. + prestosql::registerAllScalarFunctions(""); + + // exec::register + ExpressionBenchmarkBuilder benchmarkBuilder; + const vector_size_t vectorSize = 1000; + auto vectorMaker = benchmarkBuilder.vectorMaker(); + + auto makeInput = + [&](vector_size_t vectorSize, bool padAtHead, bool padAtTail) { + return vectorMaker.flatVector(vectorSize, [&](auto row) { + // Strings in even rows contain/start with/end with a_b_c depends on + // value of padAtHead && padAtTail. 
+ if (row % 2 == 0) { + auto padding = std::string(row / 2 + 1, 'x'); + if (padAtHead && padAtTail) { + return fmt::format("{}a_b_c{}", padding, padding); + } else if (padAtHead) { + return fmt::format("{}a_b_c", padding); + } else if (padAtTail) { + return fmt::format("a_b_c{}", padding); + } else { + return std::string("a_b_c"); + } + } else { + return std::string(row, 'x'); + } + }); + }; + + auto substringInput = makeInput(vectorSize, true, true); + auto prefixInput = makeInput(vectorSize, false, true); + auto suffixInput = makeInput(vectorSize, true, false); + + benchmarkBuilder + .addBenchmarkSet( + "like_substring", vectorMaker.rowVector({"col0"}, {substringInput})) + .addExpression("like_substring", R"(like(col0, '%a\_b\_c%', '\'))") + .addExpression("strpos", R"(strpos(col0, 'a_b_c') > 0)"); + + benchmarkBuilder + .addBenchmarkSet( + "like_prefix", vectorMaker.rowVector({"col0"}, {prefixInput})) + .addExpression("like_prefix", R"(like(col0, 'a\_b\_c%', '\'))") + .addExpression("starts_with", R"(starts_with(col0, 'a_b_c'))"); + + benchmarkBuilder + .addBenchmarkSet( + "like_suffix", vectorMaker.rowVector({"col0"}, {suffixInput})) + .addExpression("like_suffix", R"(like(col0, '%a\_b\_c', '\'))") + .addExpression("ends_with", R"(ends_with(col0, 'a_b_c'))"); + + benchmarkBuilder + .addBenchmarkSet( + "like_generic", vectorMaker.rowVector({"col0"}, {substringInput})) + .addExpression("like_generic", R"(like(col0, '%a%b%c'))"); + + benchmarkBuilder.registerBenchmarks(); + benchmarkBuilder.testBenchmarks(); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/benchmarks/basic/LikeFunctionsBenchmark.cpp b/velox/benchmarks/basic/LikeTpchBenchmark.cpp similarity index 100% rename from velox/benchmarks/basic/LikeFunctionsBenchmark.cpp rename to velox/benchmarks/basic/LikeTpchBenchmark.cpp diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index 1eb52f7c3c0e..39c3f79e8b7a 100644 --- a/velox/functions/lib/Re2Functions.cpp 
+++ b/velox/functions/lib/Re2Functions.cpp @@ -389,8 +389,8 @@ class Re2SearchAndExtract final : public VectorFunction { // Match string 'input' with a fixed pattern (with no wildcard characters). bool matchExactPattern( StringView input, - StringView pattern, - vector_size_t length) { + const std::string& pattern, + size_t length) { return input.size() == pattern.size() && std::memcmp(input.data(), pattern.data(), length) == 0; } @@ -398,8 +398,8 @@ bool matchExactPattern( // Match the first 'length' characters of string 'input' and prefix pattern. bool matchPrefixPattern( StringView input, - StringView pattern, - vector_size_t length) { + const std::string& pattern, + size_t length) { return input.size() >= length && std::memcmp(input.data(), pattern.data(), length) == 0; } @@ -407,8 +407,8 @@ bool matchPrefixPattern( // Match the last 'length' characters of string 'input' and suffix pattern. bool matchSuffixPattern( StringView input, - StringView pattern, - vector_size_t length) { + const std::string& pattern, + size_t length) { return input.size() >= length && std::memcmp( input.data() + input.size() - length, @@ -418,7 +418,7 @@ bool matchSuffixPattern( bool matchSubstringPattern( const StringView& input, - const StringView& fixedPattern) { + const std::string& fixedPattern) { return ( std::string_view(input).find(std::string_view(fixedPattern)) != std::string::npos); @@ -427,13 +427,14 @@ bool matchSubstringPattern( template class OptimizedLike final : public VectorFunction { public: - OptimizedLike(StringView pattern, vector_size_t reducedPatternLength) - : pattern_{pattern}, reducedPatternLength_{reducedPatternLength} {} + OptimizedLike(std::string pattern, size_t reducedPatternLength) + : pattern_{std::move(pattern)}, + reducedPatternLength_{reducedPatternLength} {} static bool match( const StringView& input, - const StringView& pattern, - vector_size_t reducedPatternLength) { + const std::string& pattern, + size_t reducedPatternLength) { switch (P) { case 
PatternKind::kExactlyN: return input.size() == reducedPatternLength; @@ -483,8 +484,8 @@ class OptimizedLike final : public VectorFunction { } private: - StringView pattern_; - vector_size_t reducedPatternLength_; + const std::string pattern_; + const size_t reducedPatternLength_; }; // This function is used when pattern and escape are constants. And there is not @@ -592,34 +593,33 @@ class LikeGeneric final : public VectorFunction { auto applyRow = [&](const StringView& input, const StringView& pattern, const std::optional& escapeChar) -> bool { - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - vector_size_t reducedLength = patternMetadata.length; - - switch (patternMetadata.patternKind) { - case PatternKind::kExactlyN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kAtLeastN: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kFixed: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kPrefix: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kSuffix: - return OptimizedLike::match( - input, pattern, reducedLength); - case PatternKind::kSubstring: - return OptimizedLike::match( - input, StringView(patternMetadata.fixedPattern), reducedLength); - default: - return applyWithRegex(input, pattern, escapeChar); - } + PatternMetadata patternMetadata = + determinePatternKind(pattern, escapeChar); + const auto reducedLength = patternMetadata.length; + const auto& fixedPattern = patternMetadata.fixedPattern; + + switch (patternMetadata.patternKind) { + case PatternKind::kExactlyN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kAtLeastN: + return OptimizedLike::match( + input, pattern, reducedLength); + case PatternKind::kFixed: + return OptimizedLike::match( + input, fixedPattern, reducedLength); + case PatternKind::kPrefix: + return OptimizedLike::match( + 
input, fixedPattern, reducedLength); + case PatternKind::kSuffix: + return OptimizedLike::match( + input, fixedPattern, reducedLength); + case PatternKind::kSubstring: + return OptimizedLike::match( + input, fixedPattern, reducedLength); + default: + return applyWithRegex(input, pattern, escapeChar); } - return applyWithRegex(input, pattern, escapeChar); }; context.ensureWritable(rows, type, localResult); @@ -970,89 +970,267 @@ std::vector> re2ExtractSignatures() { }; } -PatternMetadata determinePatternKind(StringView pattern) { - vector_size_t patternLength = pattern.size(); - vector_size_t i = 0; - // Index of the first % or _ character. - vector_size_t wildcardStart = -1; - // Count of wildcard character sequences in pattern. - vector_size_t numWildcardSequences = 0; +std::string unescape( + StringView pattern, + size_t start, + size_t end, + std::optional escapeChar) { + if (!escapeChar) { + return std::string(pattern.data() + start, end - start); + } + + std::ostringstream os; + auto cursor = pattern.begin() + start; + auto endCursor = pattern.begin() + end; + while (cursor < endCursor) { + auto previous = cursor; + + // Find the next escape char. + cursor = std::find(cursor, endCursor, escapeChar.value()); + if (cursor < endCursor) { + // There are non-escape chars, append them. + if (previous < cursor) { + os.write(previous, cursor - previous); + } + + // Make sure there is a following normal char. + VELOX_USER_CHECK( + cursor + 1 < endCursor, + "Escape character must be followed by '%', '_' or the escape character itself"); + + // Make sure the escaped char is valid. + cursor++; + auto current = *cursor; + VELOX_USER_CHECK( + current == escapeChar || current == '_' || current == '%', + "Escape character must be followed by '%', '_' or the escape character itself"); + + // Append the escaped char. + os << current; + } else { + // Escape char not found, append all the non-escape chars. 
+ os.write(previous, endCursor - previous); + break; + } + + // Advance the cursor. + cursor++; + } + + return os.str(); +} + +// Iterates through a pattern string. Transparently handles escape sequences. +class PatternStringIterator { + public: + PatternStringIterator(StringView pattern, std::optional escapeChar) + : pattern_(pattern), + escapeChar_(escapeChar), + lastIndex_{pattern_.size() - 1} {} + + // Advance the cursor to next char, escape char is automatically handled. + // Return true if the cursor is advanced successfully, false otherwise(reached + // the end of the pattern string). + bool next() { + if (currentIndex_ == lastIndex_) { + return false; + } + + isPreviousWildcard_ = + (charKind_ == CharKind::kSingleCharWildcard || + charKind_ == CharKind::kAnyCharsWildcard); + + currentIndex_++; + auto currentChar = current(); + if (currentChar == escapeChar_) { + // Escape char should be followed by another char. + VELOX_USER_CHECK_LT( + currentIndex_, + lastIndex_, + "Escape character must be followed by '%', '_' or the escape character itself: {}, escape {}", + pattern_, + escapeChar_.value()) + + currentIndex_++; + currentChar = current(); + // The char follows escapeChar can only be one of (%, _, escapeChar). + if (currentChar == escapeChar_ || currentChar == '_' || + currentChar == '%') { + charKind_ = CharKind::kNormal; + } else { + VELOX_USER_FAIL( + "Escape character must be followed by '%', '_' or the escape character itself: {}, escape {}", + pattern_, + escapeChar_.value()) + } + } else if (currentChar == '_') { + charKind_ = CharKind::kSingleCharWildcard; + } else if (currentChar == '%') { + charKind_ = CharKind::kAnyCharsWildcard; + } else { + charKind_ = CharKind::kNormal; + } + + return true; + } + + // Current index of the cursor. 
+ char currentIndex() const { + return currentIndex_; + } + + bool isAnyCharsWildcard() const { + return charKind_ == CharKind::kAnyCharsWildcard; + } + + bool isSingleCharWildcard() const { + return charKind_ == CharKind::kSingleCharWildcard; + } + + bool isWildcard() { + return isAnyCharsWildcard() || isSingleCharWildcard(); + } + + bool isPreviousWildcard() { + return isPreviousWildcard_; + } + + private: + // Represents the state of current cursor/char. + enum class CharKind { + // Wildcard char: %. + // NOTE: If escape char is set as '\', for pattern '\%%', the first '%' is + // not a wildcard, just a literal '%', the second '%' is a wildcard. + kAnyCharsWildcard, + // Wildcard char: _. + // NOTE: If escape char is set as '\', for pattern '\__', the first '_' is + // not a wildcard, just a literal '_', the second '_' is a wildcard. + kSingleCharWildcard, + // Chars that are not escape char & not wildcard char. + kNormal + }; + + // Char at current cursor. + char current() const { + return pattern_.data()[currentIndex_]; + } + + const StringView pattern_; + const std::optional escapeChar_; + const size_t lastIndex_; + + int32_t currentIndex_{-1}; + CharKind charKind_{CharKind::kNormal}; + bool isPreviousWildcard_{false}; +}; + +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar) { + int32_t patternLength = pattern.size(); + + // Index of the first % or _ character(not escaped). + int32_t wildcardStart = -1; // Index of the first character that is not % and not _. - vector_size_t fixedPatternStart = -1; + int32_t fixedPatternStart = -1; // Index of the last character in the fixed pattern, used to retrieve the // fixed string for patterns of type kSubstring. - vector_size_t fixedPatternEnd = 0; + int32_t fixedPatternEnd = -1; + // Count of wildcard character sequences in pattern. + size_t numWildcardSequences = 0; // Total number of % characters. 
- vector_size_t anyCharacterWildcardCount = 0; + size_t anyCharacterWildcardCount = 0; // Total number of _ characters. - vector_size_t singleCharacterWildcardCount = 0; - auto patternStr = pattern.data(); + size_t singleCharacterWildcardCount = 0; + + PatternStringIterator iterator{pattern, escapeChar}; - while (i < patternLength) { - if (patternStr[i] == '%' || patternStr[i] == '_') { + // Iterate through the pattern string to collect the stats for the simple + // patterns that we can optimize. + while (iterator.next()) { + if (iterator.isWildcard()) { if (wildcardStart == -1) { - wildcardStart = i; + wildcardStart = iterator.currentIndex(); } - numWildcardSequences++; - // Look till the last contiguous wildcard character, starting from this - // index, is found, or the end of pattern is reached. - while (i < patternLength && - (patternStr[i] == '%' || patternStr[i] == '_')) { - singleCharacterWildcardCount += (patternStr[i] == '_'); - anyCharacterWildcardCount += (patternStr[i] == '%'); - i++; + + if (iterator.isSingleCharWildcard()) { + ++singleCharacterWildcardCount; + } else { + ++anyCharacterWildcardCount; } - } else { - // Ensure that pattern has a single fixed pattern. - if (fixedPatternStart != -1) { - return PatternMetadata{PatternKind::kGeneric, 0}; + + if (!iterator.isPreviousWildcard()) { + ++numWildcardSequences; } - // Look till the end of fixed pattern, starting from this index, is found, - // or the end of pattern is reached. - fixedPatternStart = i; - while (i < patternLength && - (patternStr[i] != '%' && patternStr[i] != '_')) { - i++; + + // Mark the end of the fixed pattern. + if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; + } + } else { + // Record the first fixed pattern start. + if (fixedPatternStart == -1) { + fixedPatternStart = iterator.currentIndex(); + } else { + // This is not the first fixed pattern, not supported, so fallback. 
+ if (iterator.isPreviousWildcard()) { + return PatternMetadata{PatternKind::kGeneric, 0}; + } } - fixedPatternEnd = i - 1; } } + // The pattern end may not been marked if there is no wildcard char after + // pattern start, so we mark it here. + if (fixedPatternStart != -1 && fixedPatternEnd == -1) { + fixedPatternEnd = iterator.currentIndex() - 1; + } + // At this point pattern has max of one fixed pattern. // Pattern contains wildcard characters only. if (fixedPatternStart == -1) { - if (!anyCharacterWildcardCount) { + if (anyCharacterWildcardCount == 0) { return PatternMetadata{ PatternKind::kExactlyN, singleCharacterWildcardCount}; } return PatternMetadata{ PatternKind::kAtLeastN, singleCharacterWildcardCount}; } + // At this point pattern contains exactly one fixed pattern. // Pattern contains no wildcard characters (is a fixed pattern). if (wildcardStart == -1) { - return PatternMetadata{PatternKind::kFixed, patternLength}; + auto fixedPattern = unescape(pattern, 0, patternLength, escapeChar); + return PatternMetadata{ + PatternKind::kFixed, fixedPattern.size(), fixedPattern}; } + // Pattern is generic if it has '_' wildcard characters and a fixed pattern. - if (singleCharacterWildcardCount) { + if (singleCharacterWildcardCount > 0) { return PatternMetadata{PatternKind::kGeneric, 0}; } + // Classify pattern as prefix, fixed center, or suffix pattern based on the // position and count of the wildcard character sequence and fixed pattern. if (fixedPatternStart < wildcardStart) { - return PatternMetadata{PatternKind::kPrefix, wildcardStart}; + auto fixedPattern = unescape(pattern, 0, wildcardStart, escapeChar); + return PatternMetadata{ + PatternKind::kPrefix, fixedPattern.size(), fixedPattern}; } + // if numWildcardSequences > 1, then fixed pattern must be in between them. 
if (numWildcardSequences == 2) { + auto fixedPattern = + unescape(pattern, fixedPatternStart, fixedPatternEnd + 1, escapeChar); return PatternMetadata{ - PatternKind::kSubstring, - 0, - std::string( - pattern.data() + fixedPatternStart, - fixedPatternEnd + 1 - fixedPatternStart)}; + PatternKind::kSubstring, fixedPattern.size(), fixedPattern}; } + + auto fixedPattern = + unescape(pattern, fixedPatternStart, patternLength, escapeChar); + return PatternMetadata{ - PatternKind::kSuffix, patternLength - fixedPatternStart}; + PatternKind::kSuffix, fixedPattern.size(), fixedPattern}; } std::shared_ptr makeLike( @@ -1095,34 +1273,39 @@ std::shared_ptr makeLike( } auto pattern = constantPattern->as>()->valueAt(0); - if (!escapeChar) { - PatternMetadata patternMetadata = determinePatternKind(pattern); - PatternKind patternKind = patternMetadata.patternKind; - vector_size_t reducedLength = patternMetadata.length; - - switch (patternKind) { - case PatternKind::kExactlyN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kAtLeastN: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kFixed: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kPrefix: - return std::make_shared>( - pattern, reducedLength); - case PatternKind::kSuffix: - return std::make_shared>( - pattern, reducedLength); - default: - - return std::make_shared(pattern, escapeChar); - } + PatternMetadata patternMetadata; + try { + patternMetadata = determinePatternKind(pattern, escapeChar); + } catch (...) 
{ + return std::make_shared( + std::current_exception()); } - return std::make_shared(pattern, escapeChar); + size_t reducedLength = patternMetadata.length; + auto fixedPattern = patternMetadata.fixedPattern; + + switch (patternMetadata.patternKind) { + case PatternKind::kExactlyN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kAtLeastN: + return std::make_shared>( + pattern, reducedLength); + case PatternKind::kFixed: + return std::make_shared>( + fixedPattern, reducedLength); + case PatternKind::kPrefix: + return std::make_shared>( + fixedPattern, reducedLength); + case PatternKind::kSuffix: + return std::make_shared>( + fixedPattern, reducedLength); + case PatternKind::kSubstring: + return std::make_shared>( + fixedPattern, reducedLength); + default: + return std::make_shared(pattern, escapeChar); + } } std::vector> likeSignatures() { diff --git a/velox/functions/lib/Re2Functions.h b/velox/functions/lib/Re2Functions.h index d45cd416709c..e2723dd6682c 100644 --- a/velox/functions/lib/Re2Functions.h +++ b/velox/functions/lib/Re2Functions.h @@ -48,11 +48,13 @@ enum class PatternKind { struct PatternMetadata { PatternKind patternKind; - // Contains the length of the fixed pattern for patterns of kind kFixed, - // kPrefix, and kSuffix. Contains the count of wildcard character '_' for - // patterns of kind kExactlyN and kAtLeastN. Contains 0 otherwise. - vector_size_t length; - // Contains the fixed pattern in patterns of kind kSubstring. + // Contains the length of the unescaped fixed pattern for patterns of kind + // kFixed, kPrefix, kSuffix and kSubstring. Contains the count of wildcard + // character '_' for patterns of kind kExactlyN and kAtLeastN. Contains 0 + // otherwise. + size_t length; + // Contains the unescaped fixed pattern in patterns of kind kFixed, kPrefix, + // kSuffix and kSubstring. 
std::string fixedPattern = ""; }; inline const int kMaxCompiledRegexes = 20; @@ -115,7 +117,9 @@ std::vector> re2ExtractSignatures(); /// prefix, and suffix patterns. Return the pair {pattern kind, number of '_' /// characters} for patterns with wildcard characters only. Return /// {kGenericPattern, 0} for generic patterns). -PatternMetadata determinePatternKind(StringView pattern); +PatternMetadata determinePatternKind( + StringView pattern, + std::optional escapeChar); std::shared_ptr makeLike( const std::string& name, diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index cad025448889..fc50fa176f97 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -461,13 +461,21 @@ TEST_F(Re2FunctionsTest, likePattern) { } TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { - auto testPattern = [&](StringView pattern, - PatternKind patternKind, - vector_size_t length, - StringView fixedPattern = "") { - PatternMetadata patternMetadata = determinePatternKind(pattern); + auto testPattern = + [&](StringView pattern, PatternKind patternKind, vector_size_t length) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); + EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, length); + }; + + auto testPatternString = [&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = + determinePatternKind(pattern, std::nullopt); EXPECT_EQ(patternMetadata.patternKind, patternKind); - EXPECT_EQ(patternMetadata.length, length); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); EXPECT_EQ(patternMetadata.fixedPattern, fixedPattern); }; @@ -477,12 +485,14 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%", PatternKind::kAtLeastN, 0); testPattern("__%%__", PatternKind::kAtLeastN, 4); testPattern("%_%%", 
PatternKind::kAtLeastN, 1); + testPattern("%%%%%%%%%%%%", PatternKind::kAtLeastN, 0); - testPattern("presto", PatternKind::kFixed, 6); - testPattern("hello", PatternKind::kFixed, 5); - testPattern("a", PatternKind::kFixed, 1); - testPattern("helloPrestoWorld", PatternKind::kFixed, 16); - testPattern("aBcD", PatternKind::kFixed, 4); + testPatternString("presto", PatternKind::kFixed, "presto"); + testPatternString("hello", PatternKind::kFixed, "hello"); + testPatternString("a", PatternKind::kFixed, "a"); + testPatternString( + "helloPrestoWorld", PatternKind::kFixed, "helloPrestoWorld"); + testPatternString("aBcD", PatternKind::kFixed, "aBcD"); testPattern("presto%", PatternKind::kPrefix, 6); testPattern("hello%%", PatternKind::kPrefix, 5); @@ -496,11 +506,11 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("%%%helloPrestoWorld", PatternKind::kSuffix, 16); testPattern("%aBcD", PatternKind::kSuffix, 4); - testPattern("%presto%%", PatternKind::kSubstring, 0, "presto"); - testPattern("%%hello%", PatternKind::kSubstring, 0, "hello"); - testPattern("%%%aAb\n%", PatternKind::kSubstring, 0, "aAb\n"); - testPattern( - "%helloPrestoWorld%%%", PatternKind::kSubstring, 0, "helloPrestoWorld"); + testPatternString("%presto%%", PatternKind::kSubstring, "presto"); + testPatternString("%%hello%", PatternKind::kSubstring, "hello"); + testPatternString("%%%aAb\n%", PatternKind::kSubstring, "aAb\n"); + testPatternString( + "%helloPrestoWorld%%%", PatternKind::kSubstring, "helloPrestoWorld"); testPattern("_b%%__", PatternKind::kGeneric, 0); testPattern("%_%p", PatternKind::kGeneric, 0); @@ -515,6 +525,36 @@ TEST_F(Re2FunctionsTest, likeDeterminePatternKind) { testPattern("_aBcD", PatternKind::kGeneric, 0); } +TEST_F(Re2FunctionsTest, likeDeterminePatternKindWithEscapeChar) { + auto testPattern = [&](StringView pattern, + PatternKind patternKind, + StringView fixedPattern) { + PatternMetadata patternMetadata = determinePatternKind(pattern, '\\'); + 
EXPECT_EQ(patternMetadata.patternKind, patternKind); + EXPECT_EQ(patternMetadata.length, fixedPattern.size()); + EXPECT_EQ(patternMetadata.fixedPattern, fixedPattern); + }; + + testPattern(R"(\_)", PatternKind::kFixed, "_"); + testPattern(R"(\_\_\_\_)", PatternKind::kFixed, "____"); + testPattern(R"(a\_\_b\_\_c)", PatternKind::kFixed, "a__b__c"); + + testPattern(R"(\%)", PatternKind::kFixed, "%"); + testPattern(R"(\%\%\%)", PatternKind::kFixed, "%%%"); + testPattern(R"(a\%b\%c\%d)", PatternKind::kFixed, "a%b%c%d"); + + testPattern(R"(\_\_%%)", PatternKind::kPrefix, "__"); + testPattern(R"(a\_b\_c%%)", PatternKind::kPrefix, "a_b_c"); + + testPattern(R"(%%\_\_)", PatternKind::kSuffix, "__"); + testPattern(R"(%%a\_b\_c)", PatternKind::kSuffix, "a_b_c"); + testPattern(R"(%\_\%)", PatternKind::kSuffix, "_%"); + + testPattern(R"(%\_%%)", PatternKind::kSubstring, "_"); + testPattern(R"(%\_\%%%)", PatternKind::kSubstring, "_%"); + testPattern(R"(%\_ab\%%%)", PatternKind::kSubstring, "_ab%"); +} + TEST_F(Re2FunctionsTest, likePatternWildcard) { testLike("", "", true); testLike("", "%", true); @@ -573,6 +613,15 @@ TEST_F(Re2FunctionsTest, likePatternFixed) { testLike("\nabcd\n", "\nabc\nd\n", false); testLike("\nab\tcd\b", "\nabcd\b", false); + // Test literal '_' & '%' in pattern. + testLike("a", R"(\_)", '\\', false); + testLike("_b", R"(\_b)", '\\', true); + testLike("abc_d", R"(abc\_d)", '\\', true); + + testLike("a", R"(\%)", '\\', false); + testLike("abc%d", R"(abc\%d)", '\\', true); + testLike("abc%d", R"(a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input, true); } @@ -605,6 +654,15 @@ TEST_F(Re2FunctionsTest, likePatternPrefix) { testLike("\nabc\nde\n", "ab\nc%", false); testLike("\nabc\nde\n", "abc%", false); + // Test literal '_' & '%' in pattern. 
+ testLike("_", R"(\_%)", '\\', true); + testLike("_bcd", R"(\_b%)", '\\', true); + testLike("abc_defg", R"(abc\_d%)", '\\', true); + + testLike("%ab", R"(\%%)", '\\', true); + testLike("abc%defg", R"(abc\%d%)", '\\', true); + testLike("abc%defg", R"(a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 66); testLike(input, input + generateString(kAnyWildcardCharacter), true); } @@ -637,6 +695,15 @@ TEST_F(Re2FunctionsTest, likePatternSuffix) { testLike("\nabcde\n", "%d\n", false); testLike("\nabcde\n", "%e_\n", false); + // Test literal '_' & '%' in pattern. + testLike("_", R"(%\_)", '\\', true); + testLike("cd_b", R"(%\_b)", '\\', true); + testLike("efgabc_d", R"(%abc\_d)", '\\', true); + + testLike("ab%", R"(%\%)", '\\', true); + testLike("efgabc%d", R"(%abc\%d)", '\\', true); + testLike("abc%defg", R"(%a\%d)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike(input, generateString(kAnyWildcardCharacter) + input, true); } @@ -668,6 +735,14 @@ TEST_F(Re2FunctionsTest, likeSubstringPattern) { testLike("\nabcde\n", "%%d\n%", false); testLike("\nabcde\n", "%%e_\n%%", false); + // Test literal '_' & '%' in pattern. + testLike("cd_be", R"(%\_b%)", '\\', true); + testLike("efgabc_dhi", R"(%abc\_d%)", '\\', true); + + testLike("ab%cd", R"(%\%%)", '\\', true); + testLike("efgabc%dhi", R"(%abc\%d%)", '\\', true); + testLike("abc%defg", R"(%a\%d%)", '\\', false); + std::string input = generateString(kLikePatternCharacterSet, 65); testLike( input, @@ -1045,12 +1120,13 @@ TEST_F(Re2FunctionsTest, tryException) { // Make sure we do not compile more than kMaxCompiledRegexes. 
TEST_F(Re2FunctionsTest, likeRegexLimit) { - VectorPtr pattern = makeFlatVector(26); - VectorPtr input = makeFlatVector(26); + int count = 26; + VectorPtr pattern = makeFlatVector(count); + VectorPtr input = makeFlatVector(count); VectorPtr result; auto flatInput = input->asFlatVector(); - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { flatInput->set(i, ""); } @@ -1077,14 +1153,14 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { auto verifyNoRegexCompilationForPattern = [&](PatternKind patternKind) { // Over 20 all optimized, will pass. - for (int i = 0; i < 26; i++) { + for (int i = 0; i < count; i++) { std::string patternAtIdx = getPatternAtIdx(patternKind, i); flatPattern->set(i, StringView(patternAtIdx)); } result = evaluate("like(c0 , c1)", makeRowVector({input, pattern})); // Pattern '%%%', of type kAtleastN, matches with empty input. assertEqualVectors( - makeConstant((patternKind == PatternKind::kAtLeastN), 26), result); + makeConstant((patternKind == PatternKind::kAtLeastN), count), result); }; // Infer regex compilation does not happen for optimized patterns by verifying