Skip to content

Commit

Permalink
test encodings
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 22, 2024
1 parent a4fa2a7 commit c72ae6f
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 62 deletions.
153 changes: 110 additions & 43 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,63 @@ class Split final : public exec::VectorFunction {
exec::VectorWriter<Array<Varchar>> resultWriter;
resultWriter.init(*result->as<ArrayVector>());

if (args[1]->isConstantEncoding()) {
// Fast path for pattern and limit being constant.
if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) {
// Adds brackets to the input pattern for sub-pattern extraction.
const auto pattern =
args[1]->asUnchecked<ConstantVector<StringView>>()->valueAt(0);
const auto limit =
args[1]->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
if (pattern.size() == 0) {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern(input->valueAt<StringView>(row), row, resultWriter);
});
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
});
}
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
});
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
});
}
}
} else {
exec::LocalDecodedVector patterns(context, *args[1], rows);
exec::LocalDecodedVector limits(context, *args[2], rows);

rows.applyToSelected([&](vector_size_t row) {
const auto pattern = patterns->valueAt<StringView>(row);
const auto limit = limits->valueAt<int32_t>(row);
if (pattern.size() == 0) {
splitEmptyPattern(input->valueAt<StringView>(row), row, resultWriter);
if (limit > 0) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
} else {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
}
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
splitAndWrite(input->valueAt<StringView>(row), re, row, resultWriter);
if (limit > 0) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
} else {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
}
}
});
}
Expand All @@ -75,73 +109,106 @@ class Split final : public exec::VectorFunction {
}

private:
// Split each character if the pattern is empty extraction.
// When pattern is empty, split each character.
template <bool limited>
void splitEmptyPattern(
const StringView current,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter) const {
exec::VectorWriter<Array<Varchar>>& resultWriter,
uint32_t limit = 0) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
const char* pos = current.begin();
const char* const begin = current.begin();
const char* const end = current.end();
do {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
} while (pos != end);
const char* pos = begin;
if constexpr (limited) {
VELOX_DCHECK_GT(limit, 0);
while (pos != end && pos - begin < limit - 1) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
}
if (pos < end) {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
}
} else {
while (pos != end) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
}
}
resultWriter.commit();
}

// Splits 'current' around matches of the non-empty pattern 're' and writes the
// pieces as the array at 'row'.
//
// @tparam limited When true, at most 'limit - 1' delimiters are consumed and
// the rest of the input becomes the last piece. When false, 'limit' is ignored
// and the input is split at every match.
// @param current Input string. The pieces are written with setNoCopy, so they
// point into this buffer.
// @param re Compiled pattern; the matched text is read back through the
// 'piece' argument of PartialMatch.
// @param row Row of the result array being written.
// @param resultWriter Writer for the array(varchar) result.
// @param limit Maximum number of pieces. Must be positive when 'limited'.
//
// NOTE(review): a match of length zero at 'pos' would leave 'pos' unchanged
// and the loop would not advance; confirm such patterns cannot reach here.
template <bool limited>
void splitAndWrite(
    const StringView current,
    const re2::RE2& re,
    vector_size_t row,
    exec::VectorWriter<Array<Varchar>>& resultWriter,
    uint32_t limit = 0) const {
  resultWriter.setOffset(row);
  auto& arrayWriter = resultWriter.current();
  const char* pos = current.begin();
  const char* const end = current.end();

  if (pos == end) {
    // An empty input yields a single empty piece, as the earlier do-while
    // form of this loop did.
    arrayWriter.add_item().setNoCopy(StringView());
    resultWriter.commit();
    return;
  }

  if constexpr (limited) {
    VELOX_DCHECK_GT(limit, 0);
  }

  // Number of delimiters consumed so far; only consulted when 'limited'.
  [[maybe_unused]] uint32_t numPieces = 0;
  while (pos != end) {
    if constexpr (limited) {
      if (numPieces >= limit - 1) {
        break;
      }
    }

    re2::StringPiece piece;
    if (!re2::RE2::PartialMatch(
            re2::StringPiece(pos, end - pos), re, &piece)) {
      // No further delimiter: the rest of the input is the final piece.
      arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
      pos = end;
      break;
    }

    // Text between the current position and the start of the delimiter.
    arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
    if constexpr (limited) {
      ++numPieces;
    }
    if (piece.end() == end) {
      // When the found delimiter is at the end of input string, keeps
      // one empty piece of string.
      arrayWriter.add_item().setNoCopy(StringView());
    }
    pos = piece.end();
  }

  // Reached only when the limit stopped the loop early: the unsplit remainder
  // becomes the last piece.
  if (pos < end) {
    arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
  }
  resultWriter.commit();
}
};

/// Validates the argument list of split and returns the split function.
/// Raises a user error unless exactly three arguments are given, the first
/// two of varchar type and the third of integer type.
/// @param inputArgs the inputs types (VARCHAR, VARCHAR, INTEGER).
/// @return the vector function implementing split.
std::shared_ptr<exec::VectorFunction> createSplit(
    const std::string& /*name*/,
    const std::vector<exec::VectorFunctionArg>& inputArgs,
    const core::QueryConfig& /*config*/) {
  // Checking the size first makes the indexed accesses below safe.
  VELOX_USER_CHECK_EQ(
      inputArgs.size(), 3, "Three arguments are required for split function.");
  VELOX_USER_CHECK(
      inputArgs[0].type->isVarchar(),
      "The first argument should be of varchar type.");
  VELOX_USER_CHECK(
      inputArgs[1].type->isVarchar(),
      "The second argument should be of varchar type.");
  VELOX_USER_CHECK(
      inputArgs[2].type->kind() == TypeKind::INTEGER,
      "The third argument should be of integer type.");
  return std::make_shared<Split>();
}

Expand All @@ -151,9 +218,9 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
.returnType("array(varchar)")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("integer")
.build()};
}

} // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
Expand Down
79 changes: 60 additions & 19 deletions velox/functions/sparksql/tests/SplitFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,37 @@ class SplitTest : public SparkFunctionBaseTest {
void testSplit(
const std::vector<std::optional<std::string>>& input,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output,
int32_t limit = -1);

void testSplitEncodings(
const std::vector<VectorPtr>& inputs,
const std::vector<std::optional<std::vector<std::string>>>& output);

ArrayVectorPtr toArrayVector(
const std::vector<std::optional<std::vector<std::string>>>& vector);
};

// Builds an array(varchar) vector from optional rows of strings: a missing
// row becomes a null array, a present row becomes an array of its strings.
ArrayVectorPtr SplitTest::toArrayVector(
    const std::vector<std::optional<std::vector<std::string>>>& vector) {
  // Null rows report zero elements.
  const auto elementCount = [&vector](vector_size_t row) {
    const auto& entry = vector[row];
    return entry.has_value() ? entry->size() : 0;
  };
  // Views point into the strings owned by 'vector'.
  const auto elementAt = [&vector](vector_size_t row, vector_size_t idx) {
    const auto& entry = vector[row];
    if (!entry.has_value()) {
      return StringView("");
    }
    return StringView(entry->at(idx));
  };
  const auto isNullRow = [&vector](vector_size_t row) {
    return vector[row] == std::nullopt;
  };
  return makeArrayVector<StringView>(
      vector.size(), elementCount, elementAt, isNullRow);
}

void SplitTest::testSplit(
const std::vector<std::optional<std::string>>& input,
std::optional<std::string> pattern,
const std::vector<std::optional<std::vector<std::string>>>& output) {
const std::vector<std::optional<std::vector<std::string>>>& output,
int32_t limit = -1) {
auto valueAt = [&input](vector_size_t row) {
return input[row] ? StringView(*input[row]) : StringView();
};
Expand All @@ -50,28 +74,31 @@ void SplitTest::testSplit(
std::string patternString = pattern.has_value()
? std::string(", '") + pattern.value() + "'"
: ", ''";
const std::string limitString = ", " + std::to_string(limit);
std::string expressionString =
std::string("split(c0") + patternString + ")";
std::string("split(c0") + patternString + limitString + ")";
return evaluate<ArrayVector>(expressionString, rowVector);
}();

// Creating vectors for output string vectors
auto sizeAtOutput = [&output](vector_size_t row) {
return output[row] ? output[row]->size() : 0;
};
auto valueAtOutput = [&output](vector_size_t row, vector_size_t idx) {
return output[row] ? StringView(output[row]->at(idx)) : StringView("");
};
auto nullAtOutput = [&output](vector_size_t row) {
return !output[row].has_value();
};
auto expectedResult = makeArrayVector<StringView>(
output.size(), sizeAtOutput, valueAtOutput, nullAtOutput);
const auto expectedResult = toArrayVector(output);

// Checking the results
assertEqualVectors(expectedResult, result);
}

void SplitTest::testSplitEncodings(
const std::vector<VectorPtr>& inputs,
const std::vector<std::optional<std::vector<std::string>>>& output) {
const auto expected = toArrayVector(output);
std::vector<core::TypedExprPtr> inputExprs = {
std::make_shared<core::FieldAccessTypedExpr>(inputs[0]->type(), "c0"),
std::make_shared<core::FieldAccessTypedExpr>(inputs[1]->type(), "c1"),
std::make_shared<core::FieldAccessTypedExpr>(inputs[2]->type(), "c2")};
const auto expr = std::make_shared<const core::CallTypedExpr>(
expected->type(), std::move(inputExprs), "split");
testEncodings(expr, inputs, expected);
}

TEST_F(SplitTest, reallocationAndCornerCases) {
testSplit(
{"boo:and:foo", "abcfd", "abcfd:", "", ":ab::cfd::::"},
Expand Down Expand Up @@ -114,11 +141,25 @@ TEST_F(SplitTest, zeroLengthPattern) {
{{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")"}}});
}

// Splits on a character-class pattern using the default limit of testSplit.
TEST_F(SplitTest, pattern) {
  testSplit(
      {"oneAtwoBthreeC", "oneAtwoBthreeCfourD"},
      "[ABC]",
      {{{"one", "two", "three", ""}}, {{"one", "two", "three", "fourD"}}});
}

// Runs split with per-row inputs, patterns and limits through
// testSplitEncodings. Every limit here is non-positive.
TEST_F(SplitTest, encodings) {
  const std::vector<VectorPtr> inputs = {
      makeFlatVector<StringView>(
          {"oneAtwoBthreeC",
           "a chrisr:9000 here",
           "hello",
           "1001 nights",
           "morning"}),
      makeFlatVector<StringView>(
          {"[ABC]", "((\\w+):([0-9]+))", "e.*o", "(\\d+)", "(mo)|ni"}),
      makeFlatVector<int32_t>({-1, -1, 0, -1, -2}),
  };
  const std::vector<std::optional<std::vector<std::string>>> expected = {
      {{"one", "two", "three", ""}},
      {{"a ", " here"}},
      {{"h", ""}},
      {{"", "nights"}},
      {{"", "r", "ng"}}};
  testSplitEncodings(inputs, expected);
}

} // namespace
Expand Down

0 comments on commit c72ae6f

Please sign in to comment.