Skip to content

Commit

Permalink
empty pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 22, 2024
1 parent 200159a commit a4fa2a7
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 42 deletions.
78 changes: 50 additions & 28 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,28 @@ class Split final : public exec::VectorFunction {

if (args[1]->isConstantEncoding()) {
// Adds brackets to the input pattern for sub-pattern extraction.
const auto re = re2::RE2(
"(" +
args[1]->asUnchecked<ConstantVector<StringView>>()->valueAt(0).str() +
")");
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
});
const auto pattern =
args[1]->asUnchecked<ConstantVector<StringView>>()->valueAt(0);
if (pattern.size() == 0) {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern(input->valueAt<StringView>(row), row, resultWriter);
});
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
});
}
} else {
exec::LocalDecodedVector patterns(context, *args[1], rows);
rows.applyToSelected([&](vector_size_t row) {
const auto re =
re2::RE2("(" + patterns->valueAt<StringView>(row).str() + ")");
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
const auto pattern = patterns->valueAt<StringView>(row);
if (pattern.size() == 0) {
splitEmptyPattern(input->valueAt<StringView>(row), row, resultWriter);
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
}
});
}
resultWriter.finish();
Expand All @@ -66,6 +75,23 @@ class Split final : public exec::VectorFunction {
}

private:
// Split each character if the pattern is empty extraction.
void splitEmptyPattern(
const StringView current,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
const char* pos = current.begin();
const char* const end = current.end();
do {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
} while (pos != end);
resultWriter.commit();
}

// Split input string with a non-empty pattern.
void splitAndWrite(
const StringView current,
const re2::RE2& re,
Expand All @@ -75,24 +101,21 @@ class Split final : public exec::VectorFunction {
auto& arrayWriter = resultWriter.current();
const char* pos = current.begin();
const char* const end = current.end();
{
do {
re2::StringPiece piece;
if (re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
if (piece.end() == end) {
// When the found delimiter is at the end of input string, keeps
// one empty piece of string.
arrayWriter.add_item().setNoCopy(StringView());
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
do {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
if (piece.end() == end) {
// When the found delimiter is at the end of input string, keeps
// one empty piece of string.
arrayWriter.add_item().setNoCopy(StringView());
}
} while (pos != end);
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
}
} while (pos != end);
resultWriter.commit();
}
};
Expand All @@ -119,7 +142,6 @@ std::shared_ptr<exec::VectorFunction> createSplit(
inputArgs[2].type->kind() == TypeKind::INTEGER,
"The third argument should be of integer type.");
}
// TODO: Add support for zero-length pattern.
return std::make_shared<Split>();
}

Expand Down
44 changes: 30 additions & 14 deletions velox/functions/sparksql/tests/SplitFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ namespace {

class SplitTest : public SparkFunctionBaseTest {
protected:
void testSplitCharacter(
void testSplit(
const std::vector<std::optional<std::string>>& input,
std::optional<char> pattern,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output);
};

void SplitTest::testSplitCharacter(
void SplitTest::testSplit(
const std::vector<std::optional<std::string>>& input,
std::optional<char> pattern,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output) {
auto valueAt = [&input](vector_size_t row) {
return input[row] ? StringView(*input[row]) : StringView();
Expand All @@ -47,8 +47,9 @@ void SplitTest::testSplitCharacter(
auto rowVector = makeRowVector({inputString});

// Build the split expression with the quoted pattern and evaluate it on the input.
std::string patternString =
pattern.has_value() ? std::string(", '") + pattern.value() + "'" : "";
std::string patternString = pattern.has_value()
? std::string(", '") + pattern.value() + "'"
: ", ''";
std::string expressionString =
std::string("split(c0") + patternString + ")";
return evaluate<ArrayVector>(expressionString, rowVector);
Expand All @@ -72,9 +73,9 @@ void SplitTest::testSplitCharacter(
}

TEST_F(SplitTest, reallocationAndCornerCases) {
testSplitCharacter(
testSplit(
{"boo:and:foo", "abcfd", "abcfd:", "", ":ab::cfd::::"},
':',
":",
{{{"boo", "and", "foo"}},
{{"abcfd"}},
{{"abcfd", ""}},
Expand All @@ -83,9 +84,9 @@ TEST_F(SplitTest, reallocationAndCornerCases) {
}

TEST_F(SplitTest, nulls) {
testSplitCharacter(
testSplit(
{std::nullopt, "abcfd", "abcfd:", std::nullopt, ":ab::cfd::::"},
':',
":",
{{std::nullopt},
{{"abcfd"}},
{{"abcfd", ""}},
Expand All @@ -94,16 +95,31 @@ TEST_F(SplitTest, nulls) {
}

TEST_F(SplitTest, defaultArguments) {
testSplitCharacter(
{"boo:and:foo", "abcfd"}, ':', {{{"boo", "and", "foo"}}, {{"abcfd"}}});
testSplit(
{"boo:and:foo", "abcfd"}, ":", {{{"boo", "and", "foo"}}, {{"abcfd"}}});
}

TEST_F(SplitTest, longStrings) {
testSplitCharacter(
testSplit(
{"abcdefghijklkmnopqrstuvwxyz"},
',',
",",
{{{"abcdefghijklkmnopqrstuvwxyz"}}});
}

// With no pattern supplied, testSplit renders the delimiter as '' in the
// expression, so each input is expected to come back as one piece per
// character, including characters that are regex metacharacters.
TEST_F(SplitTest, zeroLengthPattern) {
  const std::vector<std::optional<std::string>> inputs = {
      "abcdefg", "abc:+%/n?(^)"};
  const std::vector<std::optional<std::vector<std::string>>> expected = {
      {{"a", "b", "c", "d", "e", "f", "g"}},
      {{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")"}}};
  testSplit(inputs, std::nullopt, expected);
}

// Splits on a character-class pattern. A delimiter that ends the input leaves
// a trailing empty piece; 'D' is not in the class, so "fourD" stays intact.
TEST_F(SplitTest, pattern) {
  const std::vector<std::optional<std::string>> inputs = {
      "oneAtwoBthreeC", "oneAtwoBthreeCfourD"};
  const std::string delimiterClass = "[ABC]";
  const std::vector<std::optional<std::vector<std::string>>> expected = {
      {{"one", "two", "three", ""}}, {{"one", "two", "three", "fourD"}}};
  testSplit(inputs, delimiterClass, expected);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit a4fa2a7

Please sign in to comment.