Skip to content

Commit

Permalink
Add replace SparkSQL function (#4922)
Browse files Browse the repository at this point in the history
Summary:
The `replace` function behavior is different in Spark and Presto.

`replace(str, replaced, replacement)`, when `replaced` is empty:

- Presto inserts `replacement` before and after each character of `str`
- Spark returns the original `str` string

```sql
replace("aa", "", "x") -- "xaxax" (Presto)
replace("aa", "", "x") -- "aa" (Spark)
```

> [SparkSQL replace:](https://github.com/apache/spark/blob/39d43e0ac3b58fb7e804362bb07665e8d6536250/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L1058)
>
> ```scala
> public UTF8String replace(UTF8String search, UTF8String replace) {
>  // This implementation is loosely based on commons-lang3's StringUtils.replace().
>  if (numBytes == 0 || search.numBytes == 0) {
>    return this;
>  }
>  // Find the first occurrence of the search string.
>  int start = 0;
>  int end = this.find(search, start);
>  if (end == -1) {
>    // Search string was not found, so string is unchanged.
>    return this;
>  }
>  // At least one match was found. Estimate space needed for result.
>  // The 16x multiplier here is chosen to match commons-lang3's implementation.
>  int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
>  final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
>  while (end != -1) {
>    buf.appendBytes(this.base, this.offset + start, end - start);
>    buf.append(replace);
>    start = end + search.numBytes;
>    end = this.find(search, start);
>  }
>  buf.appendBytes(this.base, this.offset + start, numBytes - start);
>  return buf.build();
> }
> ```

Pull Request resolved: #4922

Reviewed By: xiaoxmeng

Differential Revision: D50044543

Pulled By: kgpai

fbshipit-source-id: abee68e5ea7529be60b4f7a3eb5c1f2feb5521ff
  • Loading branch information
izchen authored and facebook-github-bot committed Jan 9, 2024
1 parent 3ac35c9 commit 005e545
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 7 deletions.
15 changes: 13 additions & 2 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,21 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
SELECT overlay('Spark SQL', 'tructured', 2, 4); -- "Structured SQL"
SELECT overlay('Spark SQL', '_', -6, 3); -- "_Sql"

.. spark:function:: replace(string, search, replace) -> string
.. spark:function:: replace(input, replaced) -> varchar
Replaces all occurrences of `search` with `replace`. ::
Removes all instances of ``replaced`` from ``input``.
If ``replaced`` is an empty string, returns the original ``input`` string. ::

SELECT replace('ABCabc', ''); -- ABCabc
SELECT replace('ABCabc', 'bc'); -- ABCc

.. spark:function:: replace(input, replaced, replacement) -> varchar
Replaces all instances of ``replaced`` with ``replacement`` in ``input``.
If ``replaced`` is an empty string, returns the original ``input`` string. ::

SELECT replace('ABCabc', '', 'DEF'); -- ABCabc
SELECT replace('ABCabc', 'abc', ''); -- ABC
SELECT replace('ABCabc', 'abc', 'DEF'); -- ABCDEF

.. spark:function:: rpad(string, len, pad) -> string
Expand Down
21 changes: 17 additions & 4 deletions velox/functions/lib/string/StringCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ cappedLengthUnicode(const char* input, size_t size, size_t maxChars) {
/// string. Search starts from startPosition. Positions start with 0. If not
/// found, -1 is returned. To facilitate finding overlapping strings, the
/// nextStartPosition is incremented by 1
static int64_t findNthInstanceByteIndexFromStart(
static inline int64_t findNthInstanceByteIndexFromStart(
const std::string_view& string,
const std::string_view subString,
const size_t instance = 1,
Expand Down Expand Up @@ -327,10 +327,14 @@ inline int64_t findNthInstanceByteIndexFromEnd(

/// Replace replaced with replacement in inputString and write results in
/// outputString. If inPlace=true inputString and outputString are assumed to
/// tbe the same. When replaced is empty, replacement is added before and after
/// each charecter. When inputString is empty results is empty.
/// be the same. When replaced is empty and ignoreEmptyReplaced is false,
/// replacement is added before and after each character. When replaced is
/// empty and ignoreEmptyReplaced is true, the result is the inputString value.
/// When inputString is empty, the result is empty.
/// replace("", "", "x") = ""
/// replace("aa", "", "x") = "xaxax"
/// replace("aa", "", "x") = "xaxax" -- when ignoreEmptyReplaced is false
/// replace("aa", "", "x") = "aa" -- when ignoreEmptyReplaced is true
template <bool ignoreEmptyReplaced = false>
inline static size_t replace(
char* outputString,
const std::string_view& inputString,
Expand All @@ -341,6 +345,15 @@ inline static size_t replace(
return 0;
}

if constexpr (ignoreEmptyReplaced) {
if (replaced.size() == 0) {
if (!inPlace) {
std::memcpy(outputString, inputString.data(), inputString.size());
}
return inputString.size();
}
}

size_t readPosition = 0;
size_t writePosition = 0;
// Copy needed in out of place replace, and when replaced and replacement are
Expand Down
6 changes: 5 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ static void workAroundRegistrationMacro(const std::string& prefix) {
// String functions.
VELOX_REGISTER_VECTOR_FUNCTION(udf_concat, prefix + "concat");
VELOX_REGISTER_VECTOR_FUNCTION(udf_lower, prefix + "lower");
VELOX_REGISTER_VECTOR_FUNCTION(udf_replace, prefix + "replace");
VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, prefix + "upper");
// Logical.
VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not");
Expand Down Expand Up @@ -239,6 +238,11 @@ void registerFunctions(const std::string& prefix) {
registerFunction<ConvFunction, Varchar, Varchar, int32_t, int32_t>(
{prefix + "conv"});

registerFunction<ReplaceFunction, Varchar, Varchar, Varchar>(
{prefix + "replace"});
registerFunction<ReplaceFunction, Varchar, Varchar, Varchar, Varchar>(
{prefix + "replace"});

// Register array sort functions.
exec::registerStatefulVectorFunction(
prefix + "array_sort", arraySortSignatures(), makeArraySort);
Expand Down
52 changes: 52 additions & 0 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Macros.h"
#include "velox/functions/UDFOutputString.h"
#include "velox/functions/lib/string/StringCore.h"
#include "velox/functions/lib/string/StringImpl.h"

namespace facebook::velox::functions::sparksql {
Expand Down Expand Up @@ -1042,4 +1043,55 @@ struct ConvFunction {
}
};

/// Spark-compatible ``replace``.
///
/// replace(input, replaced) -> varchar
///
///   Deletes every occurrence of ``replaced`` from ``input``.
///
/// replace(input, replaced, replacement) -> varchar
///
///   Substitutes ``replacement`` for every occurrence of ``replaced`` in
///   ``input``.
///
/// In both forms an empty ``replaced`` leaves ``input`` untouched: the
/// underlying stringCore::replace is instantiated with
/// ignoreEmptyReplaced = true.
template <typename T>
struct ReplaceFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// Two-argument form: equivalent to replacing with an empty string.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& replaced) {
    constexpr bool kIgnoreEmptyReplaced = true;

    const std::string_view source(input.data(), input.size());
    const std::string_view needle(replaced.data(), replaced.size());
    const std::string_view emptyReplacement;

    // Removing matches can only shrink the string, so the input length is
    // always a sufficient capacity.
    result.reserve(input.size());

    // Last argument is `false`: the result is written to a separate buffer
    // rather than in place.
    const auto writtenBytes = stringCore::replace<kIgnoreEmptyReplaced>(
        result.data(), source, needle, emptyReplacement, false);
    result.resize(writtenBytes);
  }

  /// Three-argument form.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& replaced,
      const arg_type<Varchar>& replacement) {
    constexpr bool kIgnoreEmptyReplaced = true;

    // By default the output cannot exceed the input: either nothing is
    // replaced (empty `replaced`) or each match does not grow.
    size_t capacity = input.size();

    const bool hasSearch = replaced.size() != 0;
    if (hasSearch && replacement.size() > replaced.size()) {
      // Worst case: the input is made of back-to-back matches, each of which
      // expands to `replacement`, followed by a tail too short to match.
      // NOTE(review): the arithmetic below is carried out in whatever type
      // size() returns -- confirm it cannot overflow for large inputs.
      const auto maxMatches = input.size() / replaced.size();
      const auto unmatchedTail = input.size() % replaced.size();
      capacity = maxMatches * replacement.size() + unmatchedTail;
    }
    result.reserve(capacity);

    const std::string_view source(input.data(), input.size());
    const std::string_view needle(replaced.data(), replaced.size());
    const std::string_view substitute(replacement.data(), replacement.size());

    // Last argument is `false`: the result is written to a separate buffer
    // rather than in place.
    const auto writtenBytes = stringCore::replace<kIgnoreEmptyReplaced>(
        result.data(), source, needle, substitute, false);
    result.resize(writtenBytes);
  }
};

} // namespace facebook::velox::functions::sparksql
29 changes: 29 additions & 0 deletions velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,20 @@ class StringTest : public SparkFunctionBaseTest {
std::optional<int32_t> toBase) {
return evaluateOnce<std::string>("conv(c0, c1, c2)", str, fromBase, toBase);
}

std::optional<std::string> replace(
std::optional<std::string> str,
std::optional<std::string> replaced) {
return evaluateOnce<std::string>("replace(c0, c1)", str, replaced);
}

std::optional<std::string> replace(
std::optional<std::string> str,
std::optional<std::string> replaced,
std::optional<std::string> replacement) {
return evaluateOnce<std::string>(
"replace(c0, c1, c2)", str, replaced, replacement);
}
};

TEST_F(StringTest, Ascii) {
Expand Down Expand Up @@ -772,5 +786,20 @@ TEST_F(StringTest, conv) {
EXPECT_EQ(conv("", std::nullopt, 16), std::nullopt);
EXPECT_EQ(conv("", 10, std::nullopt), std::nullopt);
}

// Covers both overloads of the Spark `replace` function.
TEST_F(StringTest, replace) {
  // Two-argument form: matches are deleted.
  EXPECT_EQ(replace("aaabaac", "a"), "bc");

  // An empty search string leaves the input unchanged in both forms.
  EXPECT_EQ(replace("aaabaac", ""), "aaabaac");
  EXPECT_EQ(replace("aaabaac", "", "z"), "aaabaac");

  // A search string that never occurs leaves the input unchanged.
  EXPECT_EQ(replace("aaabaac", "x", "z"), "aaabaac");

  // Replacement of the same length as the search string.
  EXPECT_EQ(replace("aaabaac", "a", "z"), "zzzbzzc");

  // Replacement shorter than the search string (including empty).
  EXPECT_EQ(replace("aaabaac", "a", ""), "bc");
  EXPECT_EQ(replace("aaabaac", "aaa", "z"), "zbaac");
  EXPECT_EQ(replace("aaabaac", "aaabaac", "z"), "z");

  // Replacement longer than the search string.
  EXPECT_EQ(replace("aaabaac", "a", "xyz"), "xyzxyzxyzbxyzxyzc");

  // Non-ASCII input: only the full multi-character match is replaced.
  EXPECT_EQ(
      replace("123\u6570\u6570\u636E", "\u6570\u636E", "data"),
      "123\u6570data");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 005e545

Please sign in to comment.