From e6522a18ffb98469a2f29b577189ab4cd7c19d7f Mon Sep 17 00:00:00 2001 From: Sergey Pershin Date: Fri, 13 Dec 2024 21:26:43 -0800 Subject: [PATCH] feat: Add 'expression.max_compiled_regexes' Query Config (#11850) Summary: Adding 'expression.max_compiled_regexes' Query Config property, so it can be adjusted per query. Also increasing the default value to 100 from 20. Reviewed By: yuandagits Differential Revision: D67183811 --- velox/core/QueryConfig.h | 9 +++ velox/docs/configs.rst | 4 ++ velox/docs/functions/spark/regexp.rst | 6 +- velox/functions/lib/Re2Functions.cpp | 67 ++++++++++++------- velox/functions/lib/Re2Functions.h | 25 ++++++- .../functions/lib/tests/Re2FunctionsTest.cpp | 56 +++++++++------- .../prestosql/tests/RegexpReplaceTest.cpp | 24 +++++++ velox/functions/sparksql/RegexFunctions.cpp | 43 ++++++------ velox/functions/sparksql/Split.h | 19 ++++++ .../sparksql/tests/RegexFunctionsTest.cpp | 6 +- 10 files changed, 185 insertions(+), 74 deletions(-) diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index ef6982217270..e6c4032768b1 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -82,6 +82,11 @@ class QueryConfig { static constexpr const char* kExprMaxArraySizeInReduce = "expression.max_array_size_in_reduce"; + /// Controls maximum number of compiled regular expression patterns per + /// function instance per thread of execution. + static constexpr const char* kExprMaxCompiledRegexes = + "expression.max_compiled_regexes"; + /// Used for backpressure to block local exchange producers when the local /// exchange buffer reaches or exceeds this size. 
static constexpr const char* kMaxLocalExchangeBufferSize = @@ -617,6 +622,10 @@ class QueryConfig { return get(kExprMaxArraySizeInReduce, 100'000); } + uint64_t exprMaxCompiledRegexes() const { + return get(kExprMaxCompiledRegexes, 100); + } + bool adjustTimestampToTimezone() const { return get(kAdjustTimestampToTimezone, false); } diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index e7a669c89ae5..a932ecd57558 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -183,6 +183,10 @@ Expression Evaluation Configuration - integer - 100000 - ``Reduce`` function will throw an error if encountered an array of size greater than this. + * - expression.max_compiled_regexes + - integer + - 100 + - Controls maximum number of compiled regular expression patterns per function instance per thread of execution. * - debug_disable_expression_with_peeling - bool - false diff --git a/velox/docs/functions/spark/regexp.rst b/velox/docs/functions/spark/regexp.rst index b39b0e4ae249..0552e2a4678f 100644 --- a/velox/docs/functions/spark/regexp.rst +++ b/velox/docs/functions/spark/regexp.rst @@ -26,9 +26,9 @@ See https://github.com/google/re2/wiki/Syntax for more information. Note: The wildcard '%' represents 0, 1 or multiple characters and the wildcard '_' represents exactly one character. - Note: Each function instance allow for a maximum of 20 regular expressions to - be compiled per thread of execution. Not all patterns require - compilation of regular expressions. Patterns 'hello', 'hello%', '_hello__%', + Note: Each function instance allows for a maximum of ``expression.max_compiled_regexes`` + (default 100) regular expressions to be compiled per thread of execution. Not all patterns + require compilation of regular expressions. Patterns 'hello', 'hello%', '_hello__%', '%hello', '%__hello_', '%hello%', where 'hello', 'velox' contains only regular characters and '_' wildcards are evaluated without using regular expressions. 
Only those patterns that require the compilation of diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index 44f8b66e31b9..ef1291fbe9d5 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -43,7 +43,7 @@ Expected ReCache::tryFindOrCompile(const StringView& pattern) { return reIt->second.get(); } - if (cache_.size() >= kMaxCompiledRegexes) { + if (cache_.size() >= maxCompiledRegexes_) { return folly::makeUnexpected( Status::UserError("Max number of regex reached")); } @@ -239,6 +239,8 @@ class Re2MatchConstantPattern final : public exec::VectorFunction { template class Re2Match final : public exec::VectorFunction { public: + explicit Re2Match(int64_t maxCompiledRegexes) : cache_(maxCompiledRegexes) {} + void apply( const SelectivityVector& rows, std::vector& args, @@ -359,8 +361,8 @@ class Re2SearchAndExtractConstantPattern final : public exec::VectorFunction { template class Re2SearchAndExtract final : public exec::VectorFunction { public: - explicit Re2SearchAndExtract(bool emptyNoMatch) - : emptyNoMatch_(emptyNoMatch) {} + explicit Re2SearchAndExtract(bool emptyNoMatch, int64_t maxCompiledRegexes) + : emptyNoMatch_(emptyNoMatch), cache_(maxCompiledRegexes) {} void apply( const SelectivityVector& rows, std::vector& args, @@ -886,11 +888,15 @@ class LikeWithRe2 final : public exec::VectorFunction { }; // This function is constructed when pattern or escape are not constants. -// It allows up to kMaxCompiledRegexes different regular expressions to be -// compiled throughout the query lifetime per expression and thread of -// execution, note that optimized regular expressions that are not compiled are -// not counted. +// It allows up to 'expression.max_compiled_regexes' different regular +// expressions to be compiled throughout the query lifetime per expression and +// thread of execution, note that optimized regular expressions that are not +// compiled are not counted. 
class LikeGeneric final : public exec::VectorFunction { + public: + explicit LikeGeneric(int64_t maxCompiledRegexes) + : maxCompiledRegexes_(maxCompiledRegexes) {} + void apply( const SelectivityVector& rows, std::vector& args, @@ -1008,7 +1014,7 @@ class LikeGeneric final : public exec::VectorFunction { VELOX_USER_CHECK_LT( compiledRegularExpressions_.size(), - kMaxCompiledRegexes, + maxCompiledRegexes_, "Max number of regex reached"); bool validEscapeUsage; @@ -1033,6 +1039,7 @@ class LikeGeneric final : public exec::VectorFunction { std::pair>, std::unique_ptr> compiledRegularExpressions_; + int64_t maxCompiledRegexes_; }; void re2ExtractAll( @@ -1145,6 +1152,9 @@ class Re2ExtractAllConstantPattern final : public exec::VectorFunction { template class Re2ExtractAll final : public exec::VectorFunction { public: + explicit Re2ExtractAll(int64_t maxCompiledRegexes) + : cache_(maxCompiledRegexes) {} + void apply( const SelectivityVector& rows, std::vector& args, @@ -1204,7 +1214,8 @@ class Re2ExtractAll final : public exec::VectorFunction { template std::shared_ptr makeRe2MatchImpl( const std::string& name, - const std::vector& inputArgs) { + const std::vector& inputArgs, + const core::QueryConfig& config) { if (inputArgs.size() != 2 || !inputArgs[0].type->isVarchar() || !inputArgs[1].type->isVarchar()) { VELOX_UNSUPPORTED( @@ -1220,11 +1231,14 @@ std::shared_ptr makeRe2MatchImpl( constantPattern->as>()->valueAt(0)); } - return std::make_shared>(); + return std::make_shared>(config.exprMaxCompiledRegexes()); } class RegexpReplaceWithLambdaFunction : public exec::VectorFunction { public: + explicit RegexpReplaceWithLambdaFunction(int64_t maxCompiledRegexes) + : cache_(maxCompiledRegexes) {} + void apply( const SelectivityVector& rows, std::vector& args, @@ -1592,8 +1606,8 @@ class RegexpReplaceWithLambdaFunction : public exec::VectorFunction { std::shared_ptr makeRe2Match( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { 
- return makeRe2MatchImpl(name, inputArgs); + const core::QueryConfig& config) { + return makeRe2MatchImpl(name, inputArgs, config); } std::vector> re2MatchSignatures() { @@ -1608,8 +1622,8 @@ std::vector> re2MatchSignatures() { std::shared_ptr makeRe2Search( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { - return makeRe2MatchImpl(name, inputArgs); + const core::QueryConfig& config) { + return makeRe2MatchImpl(name, inputArgs, config); } std::vector> re2SearchSignatures() { @@ -1624,7 +1638,7 @@ std::vector> re2SearchSignatures() { std::shared_ptr makeRe2Extract( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/, + const core::QueryConfig& config, const bool emptyNoMatch) { auto numArgs = inputArgs.size(); VELOX_USER_CHECK( @@ -1673,11 +1687,14 @@ std::shared_ptr makeRe2Extract( } } + const auto maxCompiledRegexes = config.exprMaxCompiledRegexes(); switch (groupIdTypeKind) { case TypeKind::INTEGER: - return std::make_shared>(emptyNoMatch); + return std::make_shared>( + emptyNoMatch, maxCompiledRegexes); case TypeKind::BIGINT: - return std::make_shared>(emptyNoMatch); + return std::make_shared>( + emptyNoMatch, maxCompiledRegexes); default: VELOX_UNREACHABLE(); } @@ -2158,14 +2175,14 @@ PatternMetadata determinePatternKind( std::shared_ptr makeLike( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { + const core::QueryConfig& config) { auto numArgs = inputArgs.size(); std::optional escapeChar; if (numArgs == 3) { BaseVector* escape = inputArgs[2].constantValue.get(); if (!escape) { - return std::make_shared(); + return std::make_shared(config.exprMaxCompiledRegexes()); } auto constantEscape = escape->as>(); @@ -2191,7 +2208,7 @@ std::shared_ptr makeLike( BaseVector* constantPattern = inputArgs[1].constantValue.get(); if (!constantPattern) { - return std::make_shared(); + return std::make_shared(config.exprMaxCompiledRegexes()); } 
if (constantPattern->isNullAt(0)) { @@ -2273,7 +2290,7 @@ std::vector> likeSignatures() { std::shared_ptr makeRe2ExtractAll( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { + const core::QueryConfig& config) { auto numArgs = inputArgs.size(); VELOX_USER_CHECK( numArgs == 2 || numArgs == 3, @@ -2318,11 +2335,12 @@ std::shared_ptr makeRe2ExtractAll( } } + const auto maxCompiledRegexes = config.exprMaxCompiledRegexes(); switch (groupIdTypeKind) { case TypeKind::INTEGER: - return std::make_shared>(); + return std::make_shared>(maxCompiledRegexes); case TypeKind::BIGINT: - return std::make_shared>(); + return std::make_shared>(maxCompiledRegexes); default: VELOX_UNREACHABLE(); } @@ -2357,7 +2375,8 @@ std::shared_ptr makeRegexpReplaceWithLambda( const std::string& name, const std::vector& inputArgs, const core::QueryConfig& config) { - return std::make_shared(); + return std::make_shared( + config.exprMaxCompiledRegexes()); } std::vector> diff --git a/velox/functions/lib/Re2Functions.h b/velox/functions/lib/Re2Functions.h index 0c03d0c58f0b..ef6712a73401 100644 --- a/velox/functions/lib/Re2Functions.h +++ b/velox/functions/lib/Re2Functions.h @@ -157,8 +157,6 @@ class PatternMetadata { std::vector substrings_; }; -inline const int kMaxCompiledRegexes = 20; - /// The functions in this file use RE2 as the regex engine. RE2 is fast, but /// supports only a subset of PCRE syntax and in particular does not support /// backtracking and associated features (e.g. backreferences). @@ -255,18 +253,26 @@ std::vector> re2ExtractAllSignatures(); namespace detail { // A cache of compiled regular expressions (RE2 instances). Allows up to -// 'kMaxCompiledRegexes' different expressions. +// 'expression.max_compiled_regexes' different expressions. // // Compiling regular expressions is expensive. It can take up to 200 times // more CPU time to compile a regex vs. evaluate it. 
class ReCache { public: + explicit ReCache(uint64_t maxCompiledRegexes) + : maxCompiledRegexes_(maxCompiledRegexes) {} + + void setMaxCompiledRegexes(uint64_t maxCompiledRegexes) { + maxCompiledRegexes_ = maxCompiledRegexes; + } + RE2* findOrCompile(const StringView& pattern); Expected tryFindOrCompile(const StringView& pattern); private: folly::F14FastMap> cache_; + uint64_t maxCompiledRegexes_; }; } // namespace detail @@ -287,6 +293,8 @@ template < std::string (*prepareRegexpPattern)(const StringView&), std::string (*prepareRegexpReplacement)(const RE2&, const StringView&)> struct Re2RegexpReplace { + Re2RegexpReplace() : cache_(0) {} + VELOX_DEFINE_FUNCTION_TYPES(T); FOLLY_ALWAYS_INLINE void initialize( @@ -304,6 +312,7 @@ struct Re2RegexpReplace { processedPattern, re_->error()); } + cache_.setMaxCompiledRegexes(config.exprMaxCompiledRegexes()); if (replacement != nullptr) { // Constant 'replacement' with non-constant 'pattern' needs to be @@ -377,8 +386,18 @@ struct Re2RegexpReplace { template struct Re2RegexpSplit { + Re2RegexpSplit() : cache_(0) {} + VELOX_DEFINE_FUNCTION_TYPES(TExec); + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& /*inputTypes*/, + const core::QueryConfig& config, + const arg_type* /*string*/, + const arg_type* /*pattern*/) { + cache_.setMaxCompiledRegexes(config.exprMaxCompiledRegexes()); + } + static constexpr int32_t reuse_strings_from_arg = 0; void call( diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 3197c750a067..58fd48ea5de9 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -1312,15 +1312,18 @@ TEST_F(Re2FunctionsTest, tryException) { } } -// Make sure we do not compile more than kMaxCompiledRegexes. +// Make sure we do not compile more than "expression.max_compiled_regexes". 
TEST_F(Re2FunctionsTest, likeRegexLimit) { - int count = 26; - VectorPtr pattern = makeFlatVector(count); - VectorPtr input = makeFlatVector(count); + const auto maxCompiledRegexes = + core::QueryConfig({}).exprMaxCompiledRegexes(); + const auto aboveMaxCompiledRegexes = maxCompiledRegexes + 5; + + VectorPtr pattern = makeFlatVector(aboveMaxCompiledRegexes); + VectorPtr input = makeFlatVector(aboveMaxCompiledRegexes); VectorPtr result; auto flatInput = input->asFlatVector(); - for (int i = 0; i < count; i++) { + for (auto i = 0; i < aboveMaxCompiledRegexes; i++) { flatInput->set(i, ""); } @@ -1346,19 +1349,21 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { }; auto verifyNoRegexCompilationForPattern = [&](PatternKind patternKind) { - // Over 20 all optimized, will pass. - for (int i = 0; i < count; i++) { + // Over maxCompiledRegexes all optimized, will pass. + for (auto i = 0; i < aboveMaxCompiledRegexes; i++) { std::string patternAtIdx = getPatternAtIdx(patternKind, i); flatPattern->set(i, StringView(patternAtIdx)); } result = evaluate("like(c0 , c1)", makeRowVector({input, pattern})); // Pattern '%%%', of type kAtleastN, matches with empty input. assertEqualVectors( - makeConstant((patternKind == PatternKind::kAtLeastN), count), result); + makeConstant( + (patternKind == PatternKind::kAtLeastN), aboveMaxCompiledRegexes), + result); }; // Infer regex compilation does not happen for optimized patterns by verifying - // less than kMaxCompiledRegexes are compiled for each optimized pattern type. + // less than maxCompiledRegexes are compiled for each optimized pattern type. 
verifyNoRegexCompilationForPattern(PatternKind::kExactlyN); verifyNoRegexCompilationForPattern(PatternKind::kAtLeastN); verifyNoRegexCompilationForPattern(PatternKind::kFixed); @@ -1366,8 +1371,8 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { verifyNoRegexCompilationForPattern(PatternKind::kSuffix); verifyNoRegexCompilationForPattern(PatternKind::kSubstring); - // Over 20, all require regex, will fail. - for (int i = 0; i < 26; i++) { + // Over maxCompiledRegexes, all require regex, will fail. + for (auto i = 0; i < aboveMaxCompiledRegexes; i++) { std::string localPattern = fmt::format("b%[0-9]+.*{}.*{}.*[0-9]+", 'c' + i, 'c' + i); flatPattern->set(i, StringView(localPattern)); @@ -1377,21 +1382,21 @@ TEST_F(Re2FunctionsTest, likeRegexLimit) { evaluate("like(c0, c1)", makeRowVector({input, pattern})), "Max number of regex reached"); - // First 20 rows should return false, the rest raise and error and become - // null. + // First maxCompiledRegexes rows should return false, the rest raise an error + // and become null. result = evaluate("try(like(c0, c1))", makeRowVector({input, pattern})); auto expected = makeFlatVector( - 26, + aboveMaxCompiledRegexes, [](auto /*row*/) { return false; }, - [](auto row) { return row >= 20; }); + [&](auto row) { return row >= maxCompiledRegexes; }); assertEqualVectors(expected, result); // All are complex but the same, should pass. - for (int i = 0; i < 26; i++) { + for (auto i = 0; i < aboveMaxCompiledRegexes; i++) { flatPattern->set(i, "b%[0-9]+.*{}.*{}.*[0-9]+"); } result = evaluate("like(c0, c1)", makeRowVector({input, pattern})); - assertEqualVectors(makeConstant(false, 26), result); + assertEqualVectors(makeConstant(false, aboveMaxCompiledRegexes), result); } TEST_F(Re2FunctionsTest, invalidEscapeChar) { @@ -1435,19 +1440,24 @@ TEST_F(Re2FunctionsTest, regexExtractAllLarge) { "No group 4611686018427387904 in regex '(\\d+)([a-z]+)") } -// Make sure we do not compile more than kMaxCompiledRegexes. 
+// Make sure we do not compile more than "expression.max_compiled_regexes". TEST_F(Re2FunctionsTest, limit) { + const auto maxCompiledRegexes = + core::QueryConfig({}).exprMaxCompiledRegexes(); + const auto aboveMaxCompiledRegexes = maxCompiledRegexes + 5; + auto data = makeRowVector({ makeFlatVector( - 100, + aboveMaxCompiledRegexes, [](auto row) { return fmt::format("Apples and oranges {}", row); }), makeFlatVector( - 100, + aboveMaxCompiledRegexes, [](auto row) { return fmt::format("Apples (.*) oranges {}", row); }), makeFlatVector( - 100, - [](auto row) { - return fmt::format("Apples (.*) oranges {}", row % 20); + aboveMaxCompiledRegexes, + [&](auto row) { + return fmt::format( + "Apples (.*) oranges {}", row % maxCompiledRegexes); }), }); diff --git a/velox/functions/prestosql/tests/RegexpReplaceTest.cpp b/velox/functions/prestosql/tests/RegexpReplaceTest.cpp index ef117d06335b..008bb16e7305 100644 --- a/velox/functions/prestosql/tests/RegexpReplaceTest.cpp +++ b/velox/functions/prestosql/tests/RegexpReplaceTest.cpp @@ -270,5 +270,29 @@ TEST_F(RegexpReplaceTest, lambda) { test::assertEqualVectors(expected, result); } +// Make sure we do not compile more than "expression.max_compiled_regexes". 
+TEST_F(RegexpReplaceTest, limit) { + const auto maxCompiledRegexes = + core::QueryConfig({}).exprMaxCompiledRegexes(); + const auto aboveMaxCompiledRegexes = maxCompiledRegexes + 5; + + auto data = makeRowVector( + {makeFlatVector( + aboveMaxCompiledRegexes, + [](auto row) { return fmt::format("Apples and oranges {}", row); }), + makeFlatVector( + aboveMaxCompiledRegexes, + [](auto row) { return fmt::format("\\d+[ab]{}", row); }), + makeFlatVector(aboveMaxCompiledRegexes, [&](auto row) { + return fmt::format("Apples (.*) oranges {}", row % maxCompiledRegexes); + })}); + + VELOX_ASSERT_THROW( + evaluate("regexp_replace(c0, c1, c2)", data), + "Max number of regex reached"); + VELOX_ASSERT_THROW( + evaluate("regexp_replace(c0, c1)", data), "Max number of regex reached"); +} + } // namespace } // namespace facebook::velox diff --git a/velox/functions/sparksql/RegexFunctions.cpp b/velox/functions/sparksql/RegexFunctions.cpp index 4b8fd2aa607b..72e00e1492a1 100644 --- a/velox/functions/sparksql/RegexFunctions.cpp +++ b/velox/functions/sparksql/RegexFunctions.cpp @@ -42,6 +42,8 @@ void ensureRegexIsConstant( // If position > length string, return string. 
template struct RegexpReplaceFunction { + RegexpReplaceFunction() : cache_(0) {} + VELOX_DEFINE_FUNCTION_TYPES(T); static constexpr bool is_default_ascii_behavior = true; @@ -49,19 +51,19 @@ struct RegexpReplaceFunction { FOLLY_ALWAYS_INLINE void initialize( const std::vector& inputTypes, const core::QueryConfig& config, - const arg_type* str, + const arg_type* stringInput, const arg_type* pattern, const arg_type* replacement) { - initialize(inputTypes, config, str, pattern, replacement, nullptr); + initialize(inputTypes, config, stringInput, pattern, replacement, nullptr); } FOLLY_ALWAYS_INLINE void initialize( const std::vector& /*inputTypes*/, - const core::QueryConfig& /*config*/, - const arg_type* /*string*/, + const core::QueryConfig& config, + const arg_type* /*stringInput*/, const arg_type* pattern, const arg_type* replacement, - const arg_type* /*position*/) { + const arg_type* /*position*/) { if (pattern) { const auto processedPattern = prepareRegexpReplacePattern(*pattern); re_.emplace(processedPattern, RE2::Quiet); @@ -79,23 +81,25 @@ struct RegexpReplaceFunction { prepareRegexpReplaceReplacement(re_.value(), *replacement); } } + cache_.setMaxCompiledRegexes(config.exprMaxCompiledRegexes()); } void call( out_type& result, const arg_type& stringInput, const arg_type& pattern, - const arg_type& replace) { - call(result, stringInput, pattern, replace, 1); + const arg_type& replacement) { + call(result, stringInput, pattern, replacement, 1); } void call( out_type& result, const arg_type& stringInput, const arg_type& pattern, - const arg_type& replace, - const arg_type& position) { - if (performChecks(result, stringInput, pattern, replace, position - 1)) { + const arg_type& replacement, + const arg_type& position) { + if (performChecks( + result, stringInput, pattern, replacement, position - 1)) { return; } size_t start = functions::stringImpl::cappedByteLength( @@ -104,27 +108,28 @@ struct RegexpReplaceFunction { result = stringInput; return; } - 
performReplace(result, stringInput, pattern, replace, start); + performReplace(result, stringInput, pattern, replacement, start); } void callAscii( out_type& result, const arg_type& stringInput, const arg_type& pattern, - const arg_type& replace) { - callAscii(result, stringInput, pattern, replace, 1); + const arg_type& replacement) { + callAscii(result, stringInput, pattern, replacement, 1); } void callAscii( out_type& result, const arg_type& stringInput, const arg_type& pattern, - const arg_type& replace, - const arg_type& position) { - if (performChecks(result, stringInput, pattern, replace, position - 1)) { + const arg_type& replacement, + const arg_type& position) { + if (performChecks( + result, stringInput, pattern, replacement, position - 1)) { return; } - performReplace(result, stringInput, pattern, replace, position - 1); + performReplace(result, stringInput, pattern, replacement, position - 1); } private: @@ -133,7 +138,7 @@ struct RegexpReplaceFunction { const arg_type& stringInput, const arg_type& pattern, const arg_type& replace, - const arg_type& position) { + const arg_type& position) { VELOX_USER_CHECK_GE( position + 1, 1, "regexp_replace requires a position >= 1"); if (position > stringInput.size()) { @@ -153,7 +158,7 @@ struct RegexpReplaceFunction { const arg_type& stringInput, const arg_type& pattern, const arg_type& replace, - const arg_type& position) { + const arg_type& position) { auto& re = ensurePattern(pattern); const auto& processedReplacement = constantReplacement_.has_value() ? constantReplacement_.value() diff --git a/velox/functions/sparksql/Split.h b/velox/functions/sparksql/Split.h index cb8fcb076700..e9666b6448da 100644 --- a/velox/functions/sparksql/Split.h +++ b/velox/functions/sparksql/Split.h @@ -28,11 +28,30 @@ namespace facebook::velox::functions::sparksql { /// applied as many times as possible. 
template struct Split { + Split() : cache_(0) {} + VELOX_DEFINE_FUNCTION_TYPES(T); // Results refer to strings in the first argument. static constexpr int32_t reuse_strings_from_arg = 0; + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& inputTypes, + const core::QueryConfig& config, + const arg_type* input, + const arg_type* delimiter) { + initialize(inputTypes, config, input, delimiter, nullptr); + } + + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& /*inputTypes*/, + const core::QueryConfig& config, + const arg_type* /*input*/, + const arg_type* /*delimiter*/, + const arg_type* /*limit*/) { + cache_.setMaxCompiledRegexes(config.exprMaxCompiledRegexes()); + } + FOLLY_ALWAYS_INLINE void call( out_type>& result, const arg_type& input, diff --git a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp index 54ced11f3bad..412ef520792c 100644 --- a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp @@ -551,8 +551,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheLimitTest) { std::vector strings; std::vector replaces; std::vector expectedOutputs; + const core::QueryConfig config({}); - for (int i = 0; i <= kMaxCompiledRegexes; ++i) { + for (int i = 0; i <= config.exprMaxCompiledRegexes(); ++i) { patterns.push_back("\\d" + std::to_string(i) + "-\\d" + std::to_string(i)); strings.push_back("1" + std::to_string(i) + "-2" + std::to_string(i)); replaces.push_back("X" + std::to_string(i) + "-Y" + std::to_string(i)); @@ -571,8 +572,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) { std::vector replaces; std::vector expectedOutputs; std::vector positions; + const core::QueryConfig config({}); - for (int i = 0; i <= kMaxCompiledRegexes - 1; ++i) { + for (int i = 0; i <= config.exprMaxCompiledRegexes() - 1; ++i) { patterns.push_back("\\d" + std::to_string(i) + "-\\d" + std::to_string(i)); strings.push_back("1" + std::to_string(i) 
+ "-2" + std::to_string(i)); replaces.push_back("X" + std::to_string(i) + "-Y" + std::to_string(i));