From 74fe778476c50e1d94ba92cda3559775898fcc4c Mon Sep 17 00:00:00 2001 From: rui-mo Date: Mon, 26 Feb 2024 16:43:30 +0800 Subject: [PATCH] Support locate function --- velox/functions/sparksql/Register.cpp | 6 ++- velox/functions/sparksql/String.h | 49 ++++++++++++++++++- velox/functions/sparksql/tests/StringTest.cpp | 33 +++++++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index bd748af7dadfd..ebff668609b42 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -221,13 +221,17 @@ void registerFunctions(const std::string& prefix) { registerCompareFunctions(prefix); registerBitwiseFunctions(prefix); - // String sreach function + // String search function registerFunction( {prefix + "startswith"}); registerFunction( {prefix + "endswith"}); registerFunction( {prefix + "contains"}); + registerFunction( + {prefix + "locate"}); + registerFunction( + {prefix + "locate"}); registerFunction({prefix + "trim"}); registerFunction({prefix + "trim"}); diff --git a/velox/functions/sparksql/String.h b/velox/functions/sparksql/String.h index ad135b71ee76c..af1f5160ddaad 100644 --- a/velox/functions/sparksql/String.h +++ b/velox/functions/sparksql/String.h @@ -253,7 +253,6 @@ struct StartsWithFunction { result = false; } else { result = str1.substr(0, str2.length()) == str2; - ; } return true; } @@ -284,6 +283,54 @@ struct EndsWithFunction { } }; +/// locate function +/// locate(string, string) -> integer +/// locate(string, string, integer) -> integer +/// Returns the position of the first occurrence of 'substr' in given string +/// after position 'start'. +template +struct LocateFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool doCall( + out_type& result, + const arg_type* substr, + const arg_type* str, + int32_t start) { + if (substr == nullptr || str == nullptr) { + return false; + } + const auto inputStr = (*str).str(); + std::size_t found = + inputStr.find_first_of((*substr).data(), start - 1, (*substr).size()); + if (found != std::string::npos) { + result = found + 1; + } else { + result = 0; + } + return true; + } + + FOLLY_ALWAYS_INLINE bool callNullable( + out_type& result, + const arg_type* substr, + const arg_type* str) { + return doCall(result, substr, str, 1); + } + + FOLLY_ALWAYS_INLINE bool callNullable( + out_type& result, + const arg_type* substr, + const arg_type* str, + const arg_type* start) { + if (start == nullptr || *start < 1) { + result = 0; + return true; + } + return doCall(result, substr, str, *start); + } +}; + /// Returns the substring from str before count occurrences of the delimiter /// delim. If count is positive, everything to the left of the final delimiter /// (counting from the left) is returned. If count is negative, everything to diff --git a/velox/functions/sparksql/tests/StringTest.cpp b/velox/functions/sparksql/tests/StringTest.cpp index dc2bfadf619f5..ceefac436632f 100644 --- a/velox/functions/sparksql/tests/StringTest.cpp +++ b/velox/functions/sparksql/tests/StringTest.cpp @@ -120,6 +120,19 @@ class StringTest : public SparkFunctionBaseTest { return evaluateOnce("contains(c0, c1)", str, pattern); } + std::optional locate( + const std::optional& substr, + const std::optional& str) { + return evaluateOnce("locate(c0, c1)", substr, str); + } + + std::optional locate( + const std::optional& substr, + const std::optional& str, + const std::optional& start) { + return evaluateOnce("locate(c0, c1, c2)", substr, str, start); + } + std::optional substring( std::optional str, std::optional start) { @@ -405,6 +418,26 @@ TEST_F(StringTest, endsWith) { EXPECT_EQ(endsWith(std::nullopt, "abc"), std::nullopt); } +TEST_F(StringTest, locate) { + EXPECT_EQ(locate("aa", "aaads"), 1); + EXPECT_EQ(locate("aa", "aaads", 0), 0); + EXPECT_EQ(locate("aa", "aaads", 1), 1); + EXPECT_EQ(locate("aa", "aaads", 2), 2); + EXPECT_EQ(locate("aa", "aaads", 3), 0); + EXPECT_EQ(locate("de", "aaads"), 0); + EXPECT_EQ(locate("de", "aaads", 2), 0); + EXPECT_EQ(locate("", ""), 1); + EXPECT_EQ(locate("", "", 3), 1); + EXPECT_EQ(locate("", "aaads"), 0); + EXPECT_EQ(locate("", "aaads", 2), 0); + EXPECT_EQ(locate("aa", ""), 0); + EXPECT_EQ(locate("aa", "", 2), 0); + EXPECT_EQ(locate("zz", "aaads", std::nullopt), 0); + EXPECT_EQ(locate("aa", std::nullopt), std::nullopt); + EXPECT_EQ(locate(std::nullopt, "aaads"), std::nullopt); + EXPECT_EQ(locate(std::nullopt, std::nullopt, std::nullopt), 0); +} + TEST_F(StringTest, substringIndex) { EXPECT_EQ(substringIndex("www.apache.org", ".", 3), "www.apache.org"); EXPECT_EQ(substringIndex("www.apache.org", ".", 2), "www.apache");