diff --git a/velox/docs/functions/spark/regexp.rst b/velox/docs/functions/spark/regexp.rst index da64dcd431ac6..93389c6c42bce 100644 --- a/velox/docs/functions/spark/regexp.rst +++ b/velox/docs/functions/spark/regexp.rst @@ -38,7 +38,12 @@ See https://github.com/google/re2/wiki/Syntax for more information. .. spark:function:: regexp_replace(string, pattern, overwrite) -> varchar Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite``. If no match is found, the original string is returned as is. - There is a limit to the number of unique regexes to be compiled per function call, which is 20. + There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception. + + regexp_replace will throw an exception if ``string`` contains an invalid UTF-8 character, or if ``pattern`` does not conform to RE2 syntax: https://github.com/google/re2/wiki/Syntax. + + regexp_replace does not support character class union, intersection, or difference and will throw an exception if they are detected within the provided ``pattern``. + Parameters: @@ -57,8 +62,13 @@ See https://github.com/google/re2/wiki/Syntax for more information. .. spark:function:: regexp_replace(string, pattern, overwrite, position) -> varchar :noindex: - Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite`` starting from the specified ``position``. If the ``position`` is less than one, the function returns an error. If ``position`` is greater than the length of ``string``, the function returns the original ``string`` without any modifications. - There is a limit to the number of unique regexes to be compiled per function call, which is 20. + Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite`` starting from the specified ``position``. If no match is found, the original string is returned as is. If the ``position`` is less than one, the function throws an exception. If ``position`` is greater than the length of ``string``, the function returns the original ``string`` without any modifications. + There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception. + + regexp_replace will throw an exception if ``string`` contains an invalid UTF-8 character, if ``position`` is less than 1, or if ``pattern`` does not conform to RE2 syntax: https://github.com/google/re2/wiki/Syntax. + + regexp_replace does not support character class union, intersection, or difference and will throw an exception if they are detected within the provided ``pattern``. + This function is 1-indexed, meaning the position of the first character is 1. Parameters: @@ -74,6 +84,6 @@ See https://github.com/google/re2/wiki/Syntax for more information. SELECT regexp_replace('Hello, World!', 'l', 'L', 6); -- 'Hello, WorLd!' - SELECT regexp_replace('Hello, World!', 'l', 'L', -5); -- 'Hello, World!' + SELECT regexp_replace('Hello, World!', 'l', 'L', 5); -- 'Hello, World!' - SELECT regexp_replace('Hello, World!', 'l', 'L', 100); -- ERROR: Position exceeds string length. + SELECT regexp_replace('Hello, World!', 'l', 'L', 100); -- 'Hello, World!' diff --git a/velox/expression/tests/SparkExpressionFuzzerTest.cpp b/velox/expression/tests/SparkExpressionFuzzerTest.cpp index 89a42d5f4ba26..17d698c5cf6d8 100644 --- a/velox/expression/tests/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/tests/SparkExpressionFuzzerTest.cpp @@ -45,8 +45,10 @@ int main(int argc, char** argv) { // The following list are the Spark UDFs that hit issues // For rlike you need the following combo in the only list: // rlike, md5 and upper + // regexp_replace: https://github.com/facebookincubator/velox/issues/8438 std::unordered_set skipFunctions = { "regexp_extract", + "regexp_replace", "rlike", "chr", "replace", diff --git a/velox/functions/sparksql/RegexFunctions.cpp b/velox/functions/sparksql/RegexFunctions.cpp index 1bf7e1cd20355..a397383cbaf77 100644 --- a/velox/functions/sparksql/RegexFunctions.cpp +++ b/velox/functions/sparksql/RegexFunctions.cpp @@ -15,6 +15,7 @@ */ #include #include "velox/functions/lib/Re2Functions.h" +#include "velox/functions/lib/string/StringImpl.h" namespace facebook::velox::functions::sparksql { namespace { @@ -110,23 +111,14 @@ template struct RegexpReplaceFunction { VELOX_DEFINE_FUNCTION_TYPES(T); + static constexpr bool is_default_ascii_behavior = true; + void call( out_type& result, const arg_type& stringInput, const arg_type& pattern, const arg_type& replace) { - re2::RE2* patternRegex = getCachedRegex(pattern.str()); - re2::StringPiece replaceStringPiece = toStringPiece(replace); - - std::string string(stringInput.data(), stringInput.size()); - RE2::GlobalReplace(&string, *patternRegex, replaceStringPiece); - - if (string.size()) { - result.resize(string.size()); - std::memcpy(result.data(), string.data(), string.size()); - } else { - result.resize(0); - } + call(result, stringInput, pattern, replace, 1); } void call( @@ -135,66 +127,89 @@ struct RegexpReplaceFunction { const arg_type& pattern, const arg_type& replace, const arg_type& position) { - VELOX_USER_CHECK_GE(position, 1, "regexp_replace requires a position >= 1"); + if (performChecks(result, stringInput, pattern, replace, position - 1)) { + return; + } + size_t start = functions::stringImpl::cappedByteLength( + stringInput, position - 1); + if (start > stringInput.size() + 1) { + result = stringInput; + return; + } + performReplace(result, stringInput, pattern, replace, start); + } - re2::RE2* patternRegex = getCachedRegex(pattern.str()); - re2::StringPiece replaceStringPiece = toStringPiece(replace); - re2::StringPiece inputStringPiece = toStringPiece(stringInput); + void callAscii( + out_type& result, + const arg_type& stringInput, + const arg_type& pattern, + const arg_type& replace) { + callAscii(result, stringInput, pattern, replace, 1); + } - if (position > stringInput.size() + 1) { - result.resize(inputStringPiece.size()); - std::memcpy( - result.data(), inputStringPiece.data(), inputStringPiece.size()); + void callAscii( + out_type& result, + const arg_type& stringInput, + const arg_type& pattern, + const arg_type& replace, + const arg_type& position) { + if (performChecks(result, stringInput, pattern, replace, position - 1)) { return; } + performReplace(result, stringInput, pattern, replace, position - 1); + } - // Adjust the position for UTF-8 by counting the code points. - size_t utf8Position = 0; - size_t numCodePoints = 0; - while (numCodePoints < position - 1 && utf8Position <= stringInput.size()) { - int charLength = - utf8proc_char_length(inputStringPiece.data() + utf8Position); - VELOX_USER_CHECK_GT( - charLength, 0, "regexp_replace encountered invalid UTF-8 character"); - ++numCodePoints; - utf8Position += charLength; + private: + bool performChecks( + out_type& result, + const arg_type& stringInput, + const arg_type& pattern, + const arg_type& replace, + const arg_type& position) { + VELOX_USER_CHECK_GE( + position + 1, 1, "regexp_replace requires a position >= 1"); + if (position > stringInput.size()) { + result = stringInput; + return true; } - if (utf8Position > stringInput.size() + 1) { - result.resize(inputStringPiece.size()); - std::memcpy( - result.data(), inputStringPiece.data(), inputStringPiece.size()); - return; + + if (stringInput.size() == 0) { + if (pattern.size() == 0 && position == 1) { + result = replace; + return true; + } + if (pattern.size() > 0) { + result = stringInput; + return true; + } } + return false; + } - re2::StringPiece prefix(inputStringPiece.data(), utf8Position); - re2::StringPiece targetStringPiece( - inputStringPiece.data() + utf8Position, - inputStringPiece.size() - utf8Position); + void performReplace( + out_type& result, + const arg_type& stringInput, + const arg_type& pattern, + const arg_type& replace, + const arg_type& position) { + re2::RE2* patternRegex = getRegex(pattern.str()); + re2::StringPiece replaceStringPiece = toStringPiece(replace); + std::string prefix(stringInput.data(), position); std::string targetString( - targetStringPiece.data(), targetStringPiece.size()); - RE2::GlobalReplace(&targetString, *patternRegex, replaceStringPiece); + stringInput.data() + position, stringInput.size() - position); - if (targetString.size() || prefix.size()) { - result.resize(prefix.size() + targetString.size()); - std::memcpy(result.data(), prefix.data(), prefix.size()); - std::memcpy( - result.data() + prefix.size(), - targetString.data(), - targetString.size()); - } else { - result.resize(0); - } + RE2::GlobalReplace(&targetString, *patternRegex, replaceStringPiece); + result = prefix + targetString; } - private: - re2::RE2* getCachedRegex(const std::string& pattern) const { - auto it = patternCache_.find(pattern); - if (it != patternCache_.end()) { + re2::RE2* getRegex(const std::string& pattern) { + auto it = cache_.find(pattern); + if (it != cache_.end()) { return it->second.get(); } VELOX_USER_CHECK_LT( - patternCache_.size(), + cache_.size(), kMaxCompiledRegexes, "regexp_replace hit the maximum number of unique regexes: {}", kMaxCompiledRegexes); @@ -202,12 +217,11 @@ struct RegexpReplaceFunction { auto patternRegex = std::make_unique(pattern); auto* rawPatternRegex = patternRegex.get(); checkForBadPattern(*rawPatternRegex); - patternCache_.emplace(pattern, std::move(patternRegex)); + cache_.emplace(pattern, std::move(patternRegex)); return rawPatternRegex; } - mutable folly::F14FastMap> - patternCache_; + folly::F14FastMap> cache_; }; } // namespace @@ -238,14 +252,14 @@ std::shared_ptr makeRegexExtract( void registerRegexpReplace(const std::string& prefix) { registerFunction( - {prefix + "REGEXP_REPLACE"}); + {prefix + "regexp_replace"}); registerFunction< RegexpReplaceFunction, Varchar, Varchar, Varchar, Varchar, - int64_t>({prefix + "REGEXP_REPLACE"}); + int32_t>({prefix + "regexp_replace"}); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index a24c0bf215aee..4ee315dcf6280 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -125,6 +125,8 @@ void registerFunctions(const std::string& prefix) { // Register size functions registerSize(prefix + "size"); + registerRegexpReplace(prefix); + registerFunction( {prefix + "get_json_object"}); diff --git a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp index 236bfeee2b899..0f2cb23049d74 100644 --- a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp @@ -20,6 +20,7 @@ #include #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/sparksql/RegexFunctions.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" @@ -34,6 +35,10 @@ class RegexFunctionsTest : public test::SparkFunctionBaseTest { void SetUp() override { SparkFunctionBaseTest::SetUp(); registerRegexpReplace(""); + // For parsing literal integers as INTEGER, not BIGINT, + // required by regexp_replace because its position argument + // is INTEGER. + options_.parseIntegerAsBigint = false; } std::optional rlike( @@ -301,16 +306,17 @@ TEST_F(RegexFunctionsTest, regexpReplaceWithEmptyString) { } TEST_F(RegexFunctionsTest, regexBadJavaPattern) { - EXPECT_THROW(testRegexpReplace("[]", "[a[b]]", ""), VeloxUserError); - EXPECT_THROW(testRegexpReplace("[]", "[a&&[b]]", ""), VeloxUserError); - EXPECT_THROW(testRegexpReplace("[]", "[a&&[^b]]", ""), VeloxUserError); + VELOX_ASSERT_THROW( + testRegexpReplace("[]", "[a[b]]", ""), + "regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])"); + VELOX_ASSERT_THROW( + testRegexpReplace("[]", "[a&&[b]]", ""), + "regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])"); + VELOX_ASSERT_THROW( + testRegexpReplace("[]", "[a&&[^b]]", ""), + "regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])"); } -TEST_F(RegexFunctionsTest, regexpReplaceInvalidUTF8) { - EXPECT_THROW( - testRegexpReplace(std::string("\xA0") + "bcacbdefg", "", "", {2}), - VeloxUserError); -} TEST_F(RegexFunctionsTest, regexpReplacePosition) { std::string output1 = "abc"; @@ -325,11 +331,15 @@ TEST_F(RegexFunctionsTest, regexpReplacePosition) { } TEST_F(RegexFunctionsTest, regexpReplaceNegativePosition) { - EXPECT_THROW(testRegexpReplace("abc", "a", "", {-1}), VeloxUserError); + VELOX_ASSERT_THROW( + testRegexpReplace("abc", "a", "", {-1}), + "regexp_replace requires a position >= 1"); } TEST_F(RegexFunctionsTest, regexpReplaceZeroPosition) { - EXPECT_THROW(testRegexpReplace("abc", "a", "", {0}), VeloxUserError); + VELOX_ASSERT_THROW( + testRegexpReplace("abc", "a", "", {0}), + "regexp_replace requires a position >= 1"); } TEST_F(RegexFunctionsTest, regexpReplacePositionTooLarge) { @@ -543,8 +553,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheLimitTest) { "X" + std::to_string(i) + "-Y" + std::to_string(i)); } - EXPECT_THROW( - testingRegexpReplaceRows(strings, patterns, replaces), VeloxUserError); + VELOX_ASSERT_THROW( + testingRegexpReplaceRows(strings, patterns, replaces), + "regexp_replace hit the maximum number of unique regexes: 20"); } TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) { @@ -564,10 +575,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) { } auto result = - testingRegexpReplaceRows(strings, patterns, replaces, positions, 50000); - auto output = convertOutput(expectedOutputs, 50000); + testingRegexpReplaceRows(strings, patterns, replaces, positions, 3); + auto output = convertOutput(expectedOutputs, 3); assertEqualVectors(result, output); } - } // namespace } // namespace facebook::velox::functions::sparksql