From c286451a4ec4198dbf16a49fcfa1ae8d53a210a5 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Wed, 20 Nov 2024 11:07:38 -0800 Subject: [PATCH] feat(function): Add Spark locate function (#8863) Summary: A function that returns the position of the first occurrence of substring in given string after the start position. Doc: https://spark.apache.org/docs/latest/api/sql/index.html#locate Spark implementation: https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L1420 Pull Request resolved: https://github.com/facebookincubator/velox/pull/8863 Differential Revision: D66203871 Pulled By: kagamiori fbshipit-source-id: cf117699a795a19786bc1df546a1578daa9757b9 --- velox/docs/functions/spark/string.rst | 29 +++++ velox/functions/lib/string/StringImpl.h | 28 +++-- .../lib/string/tests/StringImplTest.cpp | 56 +++++----- velox/functions/prestosql/StringFunctions.h | 8 +- velox/functions/sparksql/Register.cpp | 4 +- velox/functions/sparksql/String.h | 101 +++++++++++++++++- velox/functions/sparksql/tests/StringTest.cpp | 39 +++++++ 7 files changed, 213 insertions(+), 52 deletions(-) diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index d241714601e6..1bb01c65eae6 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -112,6 +112,35 @@ String Functions SELECT levenshtein('kitten', 'sitting', 10); -- 3 SELECT levenshtein('kitten', 'sitting', 2); -- -1 +.. spark:function:: locate(substring, string, start) -> integer + + Returns the 1-based position of the first occurrence of ``substring`` in the given ``string`` + at or after position ``start``. The search is from the beginning of ``string`` to the end. + ``start`` is the starting character position in ``string`` to search for the ``substring``. + ``start`` is 1-based and must be at least 1 and at most the number of characters in ``string``. 
+ The following rules on special values are applied to follow Spark's implementation. + They are listed in order of priority: + + Returns 0 if ``start`` is NULL. Returns NULL if ``substring`` or ``string`` is NULL. + Returns 0 if ``start`` is less than 1. + Returns 1 if ``substring`` is empty. + Returns 0 if ``start`` is greater than the number of characters in ``string``. + Returns 0 if ``substring`` is not found in ``string``. :: + + SELECT locate('aa', 'aaads', 1); -- 1 + SELECT locate('aa', 'aaads', -1); -- 0 + SELECT locate('aa', 'aaads', 2); -- 2 + SELECT locate('aa', 'aaads', 6); -- 0 + SELECT locate('aa', 'aaads', NULL); -- 0 + SELECT locate('', 'aaads', 1); -- 1 + SELECT locate('', 'aaads', 9); -- 1 + SELECT locate('', 'aaads', -1); -- 0 + SELECT locate('', '', 1); -- 1 + SELECT locate('aa', '', 1); -- 0 + SELECT locate(NULL, NULL, NULL); -- 0 + SELECT locate(NULL, NULL, 1); -- NULL + SELECT locate('\u4FE1', '\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B', 2); -- 4 + .. spark:function:: lower(string) -> string Returns string with all characters changed to lowercase. :: diff --git a/velox/functions/lib/string/StringImpl.h b/velox/functions/lib/string/StringImpl.h index 9671d9434662..e0a13ced7e5d 100644 --- a/velox/functions/lib/string/StringImpl.h +++ b/velox/functions/lib/string/StringImpl.h @@ -197,13 +197,17 @@ std::vector stringToCodePoints(const T& inputString) { return codePoints; } -/// Returns the starting position in characters of the Nth instance(counting -/// from the left if lpos==true and from the end otherwise) of the substring in -/// string. Positions start with 1. If not found, 0 is returned. If subString is -/// empty result is 1. -template -FOLLY_ALWAYS_INLINE int64_t -stringPosition(const T& string, const T& subString, int64_t instance = 0) { +/// Returns the starting position in characters of the Nth instance of the +/// substring in string. Positions start with 1. If not found, 0 is returned. If +/// subString is empty result is 1. 
+/// @tparam lpos If true, counting from the start of the string. Counting from +/// the end of the string otherwise. +/// @param instance The 1-based instance of the substring to find in string. +template +FOLLY_ALWAYS_INLINE int64_t stringPosition( + std::string_view string, + std::string_view subString, + int64_t instance) { VELOX_USER_CHECK_GT(instance, 0, "'instance' must be a positive number"); if (subString.size() == 0) { return 1; @@ -211,15 +215,9 @@ stringPosition(const T& string, const T& subString, int64_t instance = 0) { int64_t byteIndex = -1; if constexpr (lpos) { - byteIndex = findNthInstanceByteIndexFromStart( - std::string_view(string.data(), string.size()), - std::string_view(subString.data(), subString.size()), - instance); + byteIndex = findNthInstanceByteIndexFromStart(string, subString, instance); } else { - byteIndex = findNthInstanceByteIndexFromEnd( - std::string_view(string.data(), string.size()), - std::string_view(subString.data(), subString.size()), - instance); + byteIndex = findNthInstanceByteIndexFromEnd(string, subString, instance); } if (byteIndex == -1) { diff --git a/velox/functions/lib/string/tests/StringImplTest.cpp b/velox/functions/lib/string/tests/StringImplTest.cpp index cd0479e370d1..545d457bcc76 100644 --- a/velox/functions/lib/string/tests/StringImplTest.cpp +++ b/velox/functions/lib/string/tests/StringImplTest.cpp @@ -396,38 +396,38 @@ TEST_F(StringImplTest, stringToCodePoints) { } TEST_F(StringImplTest, overlappedStringPosition) { - auto testValidInputAsciiLpos = [](const std::string& string, - const std::string& substr, + auto testValidInputAsciiLpos = [](std::string_view string, + std::string_view substr, const int64_t instance, const int64_t expectedPosition) { - auto result = stringPosition( - StringView(string), StringView(substr), instance); + auto result = + stringPosition(string, substr, instance); ASSERT_EQ(result, expectedPosition); }; - auto testValidInputAsciiRpos = [](const std::string& string, - const 
std::string& substr, + auto testValidInputAsciiRpos = [](std::string_view string, + std::string_view substr, const int64_t instance, const int64_t expectedPosition) { - auto result = stringPosition( - StringView(string), StringView(substr), instance); + auto result = + stringPosition(string, substr, instance); ASSERT_EQ(result, expectedPosition); }; - auto testValidInputUnicodeLpos = [](const std::string& string, - const std::string& substr, + auto testValidInputUnicodeLpos = [](std::string_view string, + std::string_view substr, const int64_t instance, const int64_t expectedPosition) { - auto result = stringPosition( - StringView(string), StringView(substr), instance); + auto result = + stringPosition(string, substr, instance); ASSERT_EQ(result, expectedPosition); }; - auto testValidInputUnicodeRpos = [](const std::string& string, - const std::string& substr, + auto testValidInputUnicodeRpos = [](std::string_view string, + std::string_view substr, const int64_t instance, const int64_t expectedPosition) { - auto result = stringPosition( - StringView(string), StringView(substr), instance); + auto result = + stringPosition(string, substr, instance); ASSERT_EQ(result, expectedPosition); }; @@ -445,31 +445,27 @@ TEST_F(StringImplTest, overlappedStringPosition) { } TEST_F(StringImplTest, stringPosition) { - auto testValidInputAscii = [](const std::string& string, - const std::string& substr, + auto testValidInputAscii = [](std::string_view string, + std::string_view substr, const int64_t instance, const int64_t expectedPosition) { ASSERT_EQ( - stringPosition( - StringView(string), StringView(substr), instance), + stringPosition(string, substr, instance), expectedPosition); ASSERT_EQ( - stringPosition( - StringView(string), StringView(substr), instance), + stringPosition(string, substr, instance), expectedPosition); }; - auto testValidInputUnicode = [](const std::string& string, - const std::string& substr, + auto testValidInputUnicode = [](std::string_view string, + 
std::string_view substr, const int64_t instance, const int64_t expectedPosition) { ASSERT_EQ( - stringPosition( - StringView(string), StringView(substr), instance), + stringPosition(string, substr, instance), expectedPosition); ASSERT_EQ( - stringPosition( - StringView(string), StringView(substr), instance), + stringPosition(string, substr, instance), expectedPosition); }; @@ -494,9 +490,7 @@ TEST_F(StringImplTest, stringPosition) { testValidInputUnicode("abc/xyz/foo/bar", "/", 4, 0L); EXPECT_THROW( - stringPosition( - StringView("foobar"), StringView("foobar"), 0), - VeloxUserError); + stringPosition("foobar", "foobar", 0), VeloxUserError); } TEST_F(StringImplTest, replaceFirst) { diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index 34b67dbcc77b..c09a996b9317 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -412,7 +412,9 @@ struct StrPosFunctionBase { const arg_type& subString, const arg_type& instance = 1) { result = stringImpl::stringPosition( - string, subString, instance); + std::string_view(string.data(), string.size()), + std::string_view(subString.data(), subString.size()), + instance); } FOLLY_ALWAYS_INLINE void callAscii( @@ -421,7 +423,9 @@ struct StrPosFunctionBase { const arg_type& subString, const arg_type& instance = 1) { result = stringImpl::stringPosition( - string, subString, instance); + std::string_view(string.data(), string.size()), + std::string_view(subString.data(), subString.size()), + instance); } }; diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 00c88ab4ad19..c753eca82baa 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -303,13 +303,15 @@ void registerFunctions(const std::string& prefix) { registerCompareFunctions(prefix); registerBitwiseFunctions(prefix); - // String sreach function + // String search function 
registerFunction( {prefix + "startswith"}); registerFunction( {prefix + "endswith"}); registerFunction( {prefix + "contains"}); + registerFunction( + {prefix + "locate"}); registerFunction({prefix + "trim"}); registerFunction({prefix + "trim"}); diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index 72f106250c8b..2f188edc807c 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -262,7 +262,6 @@ struct StartsWithFunction { result = false; } else { result = str1.substr(0, str2.length()) == str2; - ; } return true; } @@ -293,6 +292,96 @@ struct EndsWithFunction { } }; +/// locate function +/// locate(substring, string, start) -> integer +/// +/// Returns the 1-based position of the first occurrence of 'substring' in +/// 'string' at or after the given 'start' position. The search is from the +/// beginning of 'string' to the end. 'start' is the starting character +/// position in 'string' to search for the 'substring'. 'start' is 1-based +/// and must be at least 1 and at most the number of characters in 'string'. +/// +/// The following rules on special values are applied to follow Spark's +/// implementation. They are listed in order of priority: +/// Returns 0 if 'start' is NULL. Returns NULL if 'substring' or 'string' is +/// NULL. Returns 0 if 'start' is less than 1. Returns 1 if 'substring' is +/// empty. Returns 0 if 'start' is greater than the number of characters in +/// 'string'. Returns 0 if 'substring' is not found in 'string'. 
+template +struct LocateFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& subString, + const arg_type& string, + const arg_type& start) { + if (start < 1) { + result = 0; + } else if (subString.empty()) { + result = 1; + } else if (start > string.size()) { + result = 0; + } else { + const auto position = stringImpl::stringPosition( + std::string_view( + string.data() + start - 1, string.size() - start + 1), + std::string_view(subString.data(), subString.size()), + 1 /*instance*/); + if (position) { + result = position + start - 1; + } else { + result = 0; + } + } + } + + FOLLY_ALWAYS_INLINE bool callNullable( + out_type& result, + const arg_type* subString, + const arg_type* string, + const arg_type* start) { + if (start == nullptr) { + result = 0; + return true; + } + if (subString == nullptr || string == nullptr) { + return false; + } + if (*start < 1) { + result = 0; + return true; + } + if (subString->empty()) { + result = 1; + return true; + } + if (*start > stringImpl::length(*string)) { + result = 0; + return true; + } + + // Find the start byte index of the start character. For example, in the + // Unicode string "😋😋😋", each character occupies 4 bytes. When 'start' is + // 2, the 'startByteIndex' is 4 which specifies the start of the second + // character. + const auto startByteIndex = stringCore::cappedByteLengthUnicode( + string->data(), string->size(), *start - 1); + + const auto position = stringImpl::stringPosition( + std::string_view( + string->data() + startByteIndex, string->size() - startByteIndex), + std::string_view(subString->data(), subString->size()), + 1 /*instance*/); + if (position) { + result = position + *start - 1; + } else { + result = 0; + } + return true; + } +}; + /// Returns the substring from str before count occurrences of the delimiter /// delim. If count is positive, everything to the left of the final delimiter /// (counting from the left) is returned. 
If count is negative, everything to @@ -321,9 +410,15 @@ struct SubstringIndexFunction { int64_t index; if (count > 0) { - index = stringImpl::stringPosition(str, delim, count); + index = stringImpl::stringPosition( + std::string_view(str.data(), str.size()), + std::string_view(delim.data(), delim.size()), + count); } else { - index = stringImpl::stringPosition(str, delim, -count); + index = stringImpl::stringPosition( + std::string_view(str.data(), str.size()), + std::string_view(delim.data(), delim.size()), + -count); } // If 'delim' is not found or found fewer than 'count' times, diff --git a/velox/functions/sparksql/tests/StringTest.cpp b/velox/functions/sparksql/tests/StringTest.cpp index f06c698a69c6..419488eb5e5e 100644 --- a/velox/functions/sparksql/tests/StringTest.cpp +++ b/velox/functions/sparksql/tests/StringTest.cpp @@ -840,6 +840,45 @@ TEST_F(StringTest, substring) { EXPECT_EQ(substringWithLength("da\u6570\u636Eta", -3, 2), "\u636Et"); } +TEST_F(StringTest, locate) { + const auto locate = [&](const std::optional& substr, + const std::optional& str, + const std::optional& start) { + return evaluateOnce("locate(c0, c1, c2)", substr, str, start); + }; + + EXPECT_EQ(locate("aa", "aaads", 1), 1); + EXPECT_EQ(locate("aa", "aaads", 0), 0); + EXPECT_EQ(locate("aa", "aaads", 2), 2); + EXPECT_EQ(locate("aa", "aaads", 3), 0); + EXPECT_EQ(locate("aa", "aaads", -3), 0); + EXPECT_EQ(locate("de", "aaads", 1), 0); + EXPECT_EQ(locate("de", "aaads", 2), 0); + EXPECT_EQ(locate("abc", "abcdddabcabc", 6), 7); + EXPECT_EQ(locate("", "", 1), 1); + EXPECT_EQ(locate("", "", 3), 1); + EXPECT_EQ(locate("", "", -1), 0); + EXPECT_EQ(locate("", "aaads", 1), 1); + EXPECT_EQ(locate("", "aaads", 9), 1); + EXPECT_EQ(locate("", "aaads", -1), 0); + EXPECT_EQ(locate("aa", "", 1), 0); + EXPECT_EQ(locate("aa", "", 2), 0); + EXPECT_EQ(locate("zz", "aaads", std::nullopt), 0); + EXPECT_EQ(locate("aa", std::nullopt, 1), std::nullopt); + EXPECT_EQ(locate(std::nullopt, "aaads", 1), 
std::nullopt); + EXPECT_EQ(locate(std::nullopt, std::nullopt, -1), std::nullopt); + EXPECT_EQ(locate(std::nullopt, std::nullopt, std::nullopt), 0); + + EXPECT_EQ(locate("", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 10), 1); + EXPECT_EQ(locate("", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", -1), 0); + EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 1), 4); + EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 0), 0); + EXPECT_EQ( + locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 2), 4); + EXPECT_EQ( + locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 8), 0); +} + TEST_F(StringTest, substringIndex) { const auto substringIndex = [&](const std::string& str, const std::string& delim, int32_t count) {