Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 23, 2024
1 parent dd6aec7 commit 90153ce
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 118 deletions.
13 changes: 8 additions & 5 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,22 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT rtrim('kr', 'spark'); -- "spa"

.. spark:function:: split(string, delimiter, limit) -> array(string)
.. spark:function:: split(string, delimiter[, limit]) -> array(string)
Splits ``string`` around occurrences that match ``delimiter`` and returns an array
with a length of at most ``limit``. ``delimiter`` is a string representing a regular
expression. ``limit`` is an integer which controls the number of times the regex is
applied. When ``limit`` > 0, the resulting array's length will not be more than
``limit``, and the resulting array's last entry will contain all input beyond the
last matched regex. When ``limit`` <= 0, ``regex`` will be applied as many times as
possible, and the resulting array can be of any size. ::
applied. By default, ``limit`` is -1. When ``limit`` > 0, the resulting array's
length will not be more than ``limit``, and the resulting array's last entry will
contain all input beyond the last matched regex. When ``limit`` <= 0, ``regex`` will
be applied as many times as possible, and the resulting array can be of any size. ::

SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""]
SELECT split('oneAtwoBthreeC', '[ABC]', 2); -- ["one","twoBthreeC"]
SELECT split('one', ''); -- ["o", "n", "e", ""]
SELECT split('one', '1'); -- ["one"]
SELECT split('abcd', ''); -- ["a", "b", "c", "d"]
SELECT split('abcd', '', 3); -- ["a", "b", "c"]

.. spark:function:: split(string, delimiter, limit) -> array(string)
:noindex:
Expand Down
184 changes: 80 additions & 104 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,64 +36,60 @@ class Split final : public exec::VectorFunction {
BaseVector::ensureWritable(rows, ARRAY(VARCHAR()), context.pool(), result);
exec::VectorWriter<Array<Varchar>> resultWriter;
resultWriter.init(*result->as<ArrayVector>());
int32_t limit = -1;

// Fast path for pattern and limit being constant.
if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) {
// Fast path for pattern and limit being constants.
if (args[1]->isConstantEncoding() &&
(args.size() == 2 || args[2]->isConstantEncoding())) {
// Adds brackets to the input pattern for sub-pattern extraction.
const auto pattern =
args[1]->asUnchecked<ConstantVector<StringView>>()->valueAt(0);
const auto limit =
args[2]->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
if (args.size() > 2) {
limit = args[2]->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
}
const auto positiveLimit =
limit > 0 ? limit : std::numeric_limits<uint32_t>::max();
if (pattern.size() == 0) {
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
});
}
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern(
input->valueAt<StringView>(row),
row,
resultWriter,
positiveLimit);
});
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
});
}
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite(
input->valueAt<StringView>(row),
re,
row,
resultWriter,
positiveLimit);
});
}
} else {
exec::LocalDecodedVector patterns(context, *args[1], rows);
exec::LocalDecodedVector limits(context, *args[2], rows);

rows.applyToSelected([&](vector_size_t row) {
const auto pattern = patterns->valueAt<StringView>(row);
const auto limit = limits->valueAt<int32_t>(row);
limit = limits->valueAt<int32_t>(row);
const auto positiveLimit =
limit > 0 ? limit : std::numeric_limits<uint32_t>::max();
if (pattern.size() == 0) {
if (limit > 0) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
} else {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
}
splitEmptyPattern(
input->valueAt<StringView>(row),
row,
resultWriter,
positiveLimit);
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
if (limit > 0) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
} else {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
}
splitAndWrite(
input->valueAt<StringView>(row),
re2::RE2("(" + pattern.str() + ")"),
row,
resultWriter,
positiveLimit);
}
});
}
Expand All @@ -108,12 +104,11 @@ class Split final : public exec::VectorFunction {

private:
// When pattern is empty, split each character.
template <bool limited>
void splitEmptyPattern(
const StringView current,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter,
uint32_t limit = 0) const {
uint32_t limit) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
if (current.size() == 0) {
Expand All @@ -125,32 +120,20 @@ class Split final : public exec::VectorFunction {
const char* const begin = current.begin();
const char* const end = current.end();
const char* pos = begin;
if constexpr (limited) {
VELOX_DCHECK_GT(limit, 0);
while (pos != end && pos - begin < limit - 1) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
}
if (pos < end) {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
}
} else {
while (pos != end) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
}
while (pos < end && pos < limit + begin) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
}
resultWriter.commit();
}

// Split with a non-empty pattern.
template <bool limited>
void splitAndWrite(
const StringView current,
const re2::RE2& re,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter,
uint32_t limit = 0) const {
uint32_t limit) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
if (current.size() == 0) {
Expand All @@ -161,42 +144,26 @@ class Split final : public exec::VectorFunction {

const char* pos = current.begin();
const char* const end = current.end();
if constexpr (limited) {
VELOX_DCHECK_GT(limit, 0);
uint32_t numPieces = 0;
while (pos != end && numPieces < limit - 1) {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
numPieces += 1;
if (piece.end() == end) {
// When the found delimiter is at the end of input string, keeps
// one empty piece of string.
arrayWriter.add_item().setNoCopy(StringView());
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
VELOX_DCHECK_GT(limit, 0);
uint32_t numPieces = 0;
while (pos < end && numPieces < limit - 1) {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
numPieces += 1;
if (piece.end() == end) {
// When the found delimiter is at the end of input string, keeps
// one empty piece of string.
arrayWriter.add_item().setNoCopy(StringView());
}
}
if (pos < end) {
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
}
} else {
while (pos != end) {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
if (piece.end() == end) {
arrayWriter.add_item().setNoCopy(StringView());
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
}
}
}
if (pos < end) {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
}
resultWriter.commit();
}
Expand All @@ -208,28 +175,37 @@ std::shared_ptr<exec::VectorFunction> createSplit(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
VELOX_USER_CHECK_EQ(
inputArgs.size(), 3, "Three arguments are required for split function.");
VELOX_USER_CHECK(
inputArgs.size() == 2 || inputArgs.size() == 3,
"Two or three arguments are required for split function.");
VELOX_USER_CHECK(
inputArgs[0].type->isVarchar(),
"The first argument should be of varchar type.");
VELOX_USER_CHECK(
inputArgs[1].type->isVarchar(),
"The second argument should be of varchar type.");
VELOX_USER_CHECK(
inputArgs[2].type->kind() == TypeKind::INTEGER,
"The third argument should be of integer type.");
if (inputArgs.size() > 2) {
VELOX_USER_CHECK(
inputArgs[2].type->kind() == TypeKind::INTEGER,
"The third argument should be of integer type.");
}
return std::make_shared<Split>();
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// varchar, varchar -> array(varchar)
return {exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("integer")
.build()};
return {
exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.argumentType("varchar")
.build(),
exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("integer")
.build()};
}

} // namespace
Expand Down
26 changes: 17 additions & 9 deletions velox/functions/sparksql/tests/SplitFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SplitTest : public SparkFunctionBaseTest {
const std::vector<std::optional<std::string>>& input,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output,
int32_t limit = -1);
std::optional<int32_t> limit = std::nullopt);

void testSplitEncodings(
const std::vector<VectorPtr>& inputs,
Expand Down Expand Up @@ -57,7 +57,7 @@ void SplitTest::testSplit(
const std::vector<std::optional<std::string>>& input,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output,
int32_t limit) {
std::optional<int32_t> limit) {
auto valueAt = [&input](vector_size_t row) {
return input[row] ? StringView(*input[row]) : StringView();
};
Expand All @@ -74,8 +74,9 @@ void SplitTest::testSplit(
std::string patternString = pattern.has_value()
? std::string(", '") + pattern.value() + "'"
: ", ''";
const std::string limitString =
", '" + std::to_string(limit) + "'::INTEGER";
const std::string limitString = limit.has_value()
? ", '" + std::to_string(limit.value()) + "'::INTEGER"
: "";
std::string expressionString =
std::string("split(c0") + patternString + limitString + ")";
return evaluate<ArrayVector>(expressionString, rowVector);
Expand All @@ -93,8 +94,11 @@ void SplitTest::testSplitEncodings(
const auto expected = toArrayVector(output);
std::vector<core::TypedExprPtr> inputExprs = {
std::make_shared<core::FieldAccessTypedExpr>(inputs[0]->type(), "c0"),
std::make_shared<core::FieldAccessTypedExpr>(inputs[1]->type(), "c1"),
std::make_shared<core::FieldAccessTypedExpr>(inputs[2]->type(), "c2")};
std::make_shared<core::FieldAccessTypedExpr>(inputs[1]->type(), "c1")};
if (inputs.size() > 2) {
inputExprs.emplace_back(
std::make_shared<core::FieldAccessTypedExpr>(inputs[2]->type(), "c2"));
}
const auto expr = std::make_shared<const core::CallTypedExpr>(
expected->type(), std::move(inputExprs), "split");
testEncodings(expr, inputs, expected);
Expand Down Expand Up @@ -143,10 +147,13 @@ TEST_F(SplitTest, zeroLengthPattern) {
{{{"a", "b", "c", "d", "e", "f", "g"}},
{{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")"}},
{{""}}});

// The result does not include remaining string when limit is smaller than the
// string size.
testSplit(
{"abcdefg", "abc:+%/n?(^)", ""},
{"abcdefg", "ab:c+%/n?(^)", ""},
std::nullopt,
{{{"a", "b", "cdefg"}}, {{"a", "b", "c:+%/n?(^)"}}, {{""}}},
{{{"a", "b", "c"}}, {{"a", "b", ":"}}, {{""}}},
3);
testSplit(
{"abcdefg", "abc:+%/n?(^)", ""},
Expand Down Expand Up @@ -200,10 +207,11 @@ TEST_F(SplitTest, encodings) {
}},
{{"", ""}}};
testSplitEncodings({strings, patterns, limits}, expected);
testSplitEncodings({strings, patterns}, expected);

limits = makeFlatVector<int32_t>({3, 3, 2, 1, 5, 2, 1, 1, 2, 2});
expected = {
{{"a", "b", "cdef"}},
{{"a", "b", "c"}},
{{"one", "two", "threeC"}},
{{"aa", "bb3cc"}},
{{"hello"}},
Expand Down

0 comments on commit 90153ce

Please sign in to comment.