Skip to content

Commit

Permalink
Support locate function
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 27, 2024
1 parent 52ad7f5 commit 74fe778
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
6 changes: 5 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,17 @@ void registerFunctions(const std::string& prefix) {
registerCompareFunctions(prefix);
registerBitwiseFunctions(prefix);

// String sreach function
// String search function
registerFunction<StartsWithFunction, bool, Varchar, Varchar>(
{prefix + "startswith"});
registerFunction<EndsWithFunction, bool, Varchar, Varchar>(
{prefix + "endswith"});
registerFunction<ContainsFunction, bool, Varchar, Varchar>(
{prefix + "contains"});
registerFunction<LocateFunction, int32_t, Varchar, Varchar>(
{prefix + "locate"});
registerFunction<LocateFunction, int32_t, Varchar, Varchar, int32_t>(
{prefix + "locate"});

registerFunction<TrimSpaceFunction, Varchar, Varchar>({prefix + "trim"});
registerFunction<TrimFunction, Varchar, Varchar, Varchar>({prefix + "trim"});
Expand Down
49 changes: 48 additions & 1 deletion velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ struct StartsWithFunction {
result = false;
} else {
result = str1.substr(0, str2.length()) == str2;
;
}
return true;
}
Expand Down Expand Up @@ -284,6 +283,54 @@ struct EndsWithFunction {
}
};

/// locate function
/// locate(string, string) -> integer
/// locate(string, string, integer) -> integer
/// Returns the position of the first occurrence of 'substr' in given string
/// after position 'start'.
template <typename T>
struct LocateFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE bool doCall(
out_type<int32_t>& result,
const arg_type<Varchar>* substr,
const arg_type<Varchar>* str,
int32_t start) {
if (substr == nullptr || str == nullptr) {
return false;
}
const auto inputStr = (*str).str();
std::size_t found =
inputStr.find_first_of((*substr).data(), start - 1, (*substr).size());
if (found != std::string::npos) {
result = found + 1;
} else {
result = 0;
}
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<int32_t>& result,
const arg_type<Varchar>* substr,
const arg_type<Varchar>* str) {
return doCall(result, substr, str, 1);
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<int32_t>& result,
const arg_type<Varchar>* substr,
const arg_type<Varchar>* str,
const arg_type<int32_t>* start) {
if (start == nullptr || *start < 1) {
result = 0;
return true;
}
return doCall(result, substr, str, *start);
}
};

/// Returns the substring from str before count occurrences of the delimiter
/// delim. If count is positive, everything to the left of the final delimiter
/// (counting from the left) is returned. If count is negative, everything to
Expand Down
33 changes: 33 additions & 0 deletions velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,19 @@ class StringTest : public SparkFunctionBaseTest {
return evaluateOnce<bool>("contains(c0, c1)", str, pattern);
}

std::optional<int32_t> locate(
const std::optional<std::string>& substr,
const std::optional<std::string>& str) {
return evaluateOnce<int32_t>("locate(c0, c1)", substr, str);
}

std::optional<int32_t> locate(
const std::optional<std::string>& substr,
const std::optional<std::string>& str,
const std::optional<int32_t>& start) {
return evaluateOnce<int32_t>("locate(c0, c1, c2)", substr, str, start);
}

std::optional<std::string> substring(
std::optional<std::string> str,
std::optional<int32_t> start) {
Expand Down Expand Up @@ -405,6 +418,26 @@ TEST_F(StringTest, endsWith) {
EXPECT_EQ(endsWith(std::nullopt, "abc"), std::nullopt);
}

TEST_F(StringTest, locate) {
EXPECT_EQ(locate("aa", "aaads"), 1);
EXPECT_EQ(locate("aa", "aaads", 0), 0);
EXPECT_EQ(locate("aa", "aaads", 1), 1);
EXPECT_EQ(locate("aa", "aaads", 2), 2);
EXPECT_EQ(locate("aa", "aaads", 3), 0);
EXPECT_EQ(locate("de", "aaads"), 0);
EXPECT_EQ(locate("de", "aaads", 2), 0);
EXPECT_EQ(locate("", ""), 1);
EXPECT_EQ(locate("", "", 3), 1);
EXPECT_EQ(locate("", "aaads"), 0);
EXPECT_EQ(locate("", "aaads", 2), 0);
EXPECT_EQ(locate("aa", ""), 0);
EXPECT_EQ(locate("aa", "", 2), 0);
EXPECT_EQ(locate("zz", "aaads", std::nullopt), 0);
EXPECT_EQ(locate("aa", std::nullopt), std::nullopt);
EXPECT_EQ(locate(std::nullopt, "aaads"), std::nullopt);
EXPECT_EQ(locate(std::nullopt, std::nullopt, std::nullopt), 0);
}

TEST_F(StringTest, substringIndex) {
EXPECT_EQ(substringIndex("www.apache.org", ".", 3), "www.apache.org");
EXPECT_EQ(substringIndex("www.apache.org", ".", 2), "www.apache");
Expand Down

0 comments on commit 74fe778

Please sign in to comment.