diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index 3e1405b50a3ee..6715c52d6e427 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -39,29 +39,63 @@ class Split final : public exec::VectorFunction { exec::VectorWriter> resultWriter; resultWriter.init(*result->as()); - if (args[1]->isConstantEncoding()) { + // Fast path for pattern and limit being constant. + if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) { // Adds brackets to the input pattern for sub-pattern extraction. const auto pattern = args[1]->asUnchecked>()->valueAt(0); + const auto limit = + args[2]->asUnchecked>()->valueAt(0); if (pattern.size() == 0) { - rows.applyToSelected([&](vector_size_t row) { - splitEmptyPattern(input->valueAt(row), row, resultWriter); - }); + if (limit > 0) { + rows.applyToSelected([&](vector_size_t row) { + splitEmptyPattern( + input->valueAt(row), row, resultWriter, limit); + }); + } else { + rows.applyToSelected([&](vector_size_t row) { + splitEmptyPattern( + input->valueAt(row), row, resultWriter); + }); + } } else { const auto re = re2::RE2("(" + pattern.str() + ")"); - rows.applyToSelected([&](vector_size_t row) { - splitAndWrite(input->valueAt(row), re, row, resultWriter); - }); + if (limit > 0) { + rows.applyToSelected([&](vector_size_t row) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter, limit); + }); + } else { + rows.applyToSelected([&](vector_size_t row) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter); + }); + } } } else { exec::LocalDecodedVector patterns(context, *args[1], rows); + exec::LocalDecodedVector limits(context, *args[2], rows); + rows.applyToSelected([&](vector_size_t row) { const auto pattern = patterns->valueAt(row); + const auto limit = limits->valueAt(row); if (pattern.size() == 0) { - splitEmptyPattern(input->valueAt(row), row, resultWriter); + if (limit > 0) { + splitEmptyPattern(
+ input->valueAt(row), row, resultWriter, limit); + } else { + splitEmptyPattern( + input->valueAt(row), row, resultWriter); + } } else { const auto re = re2::RE2("(" + pattern.str() + ")"); - splitAndWrite(input->valueAt(row), re, row, resultWriter); + if (limit > 0) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter, limit); + } else { + splitAndWrite( + input->valueAt(row), re, row, resultWriter); + } } }); } @@ -75,73 +109,106 @@ class Split final : public exec::VectorFunction { } private: - // Split each character if the pattern is empty extraction. + // Empty pattern: split each character; if limited, last item is the rest. + template void splitEmptyPattern( const StringView current, vector_size_t row, - exec::VectorWriter>& resultWriter) const { + exec::VectorWriter>& resultWriter, + uint32_t limit = 0) const { resultWriter.setOffset(row); auto& arrayWriter = resultWriter.current(); - const char* pos = current.begin(); + const char* const begin = current.begin(); const char* const end = current.end(); - do { - arrayWriter.add_item().setNoCopy(StringView(pos, 1)); - pos += 1; - } while (pos != end); + const char* pos = begin; + if constexpr (limited) { + VELOX_DCHECK_GT(limit, 0); + while (pos != end && pos - begin < limit - 1) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; + } + if (pos < end) { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + } + } else { + while (pos != end) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; + } + } resultWriter.commit(); } - // Split input string with a non-empty pattern. + // Split with a non-empty pattern; if limited, write at most 'limit' items.
+ template void splitAndWrite( const StringView current, const re2::RE2& re, vector_size_t row, - exec::VectorWriter>& resultWriter) const { + exec::VectorWriter>& resultWriter, + uint32_t limit = 0) const { resultWriter.setOffset(row); auto& arrayWriter = resultWriter.current(); const char* pos = current.begin(); const char* const end = current.end(); - do { - if (re2::StringPiece piece; re2::RE2::PartialMatch( - re2::StringPiece(pos, end - pos), re, &piece)) { - arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); - if (piece.end() == end) { - // When the found delimiter is at the end of input string, keeps - // one empty piece of string. - arrayWriter.add_item().setNoCopy(StringView()); + if constexpr (limited) { + VELOX_DCHECK_GT(limit, 0); + uint32_t numPieces = 0; + while (pos != end && numPieces < limit - 1) { + if (re2::StringPiece piece; re2::RE2::PartialMatch( + re2::StringPiece(pos, end - pos), re, &piece)) { + arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); + numPieces += 1; + if (piece.end() == end) { + // When the found delimiter is at the end of input string, keeps + // one empty piece of string.
+ arrayWriter.add_item().setNoCopy(StringView()); + } + pos = piece.end(); + } else { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + pos = end; } - pos = piece.end(); - } else { + } + if (pos < end || current.size() == 0) { arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); - pos = end; } - } while (pos != end); + } else { + do { + if (re2::StringPiece piece; re2::RE2::PartialMatch( + re2::StringPiece(pos, end - pos), re, &piece)) { + arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); + if (piece.end() == end) { + arrayWriter.add_item().setNoCopy(StringView()); + } + pos = piece.end(); + } else { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + pos = end; + } + } while (pos != end); + } resultWriter.commit(); } }; -/// The function returns specialized version of split based on the constant -/// inputs. -/// \param inputArgs the inputs types (VARCHAR, VARCHAR, int64) and constant -/// values (if provided). +/// Returns split function. +/// @param inputArgs the inputs types (VARCHAR, VARCHAR, int32). std::shared_ptr createSplit( const std::string& /*name*/, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { - VELOX_USER_CHECK_LE( - inputArgs.size(), 3, "The number of arguments should not exceed 3."); + VELOX_USER_CHECK_EQ( + inputArgs.size(), 3, "Three arguments are required for split function."); VELOX_USER_CHECK( inputArgs[0].type->isVarchar(), "The first argument should be of varchar type."); VELOX_USER_CHECK( inputArgs[1].type->isVarchar(), "The second argument should be of varchar type."); - // TODO: support the third argument.
- if (inputArgs.size() > 2) { - VELOX_USER_CHECK( - inputArgs[2].type->kind() == TypeKind::INTEGER, - "The third argument should be of integer type."); - } + VELOX_USER_CHECK( + inputArgs[2].type->kind() == TypeKind::INTEGER, + "The third argument should be of integer type."); return std::make_shared(); } @@ -151,9 +218,9 @@ std::vector> signatures() { .returnType("array(varchar)") .argumentType("varchar") .argumentType("varchar") + .argumentType("integer") .build()}; } - } // namespace VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index 94435d8cf7f62..b951e72a87423 100644 --- a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -27,13 +27,37 @@ class SplitTest : public SparkFunctionBaseTest { void testSplit( const std::vector>& input, std::optional pattern, + const std::vector>>& output, + int32_t limit = -1); + + void testSplitEncodings( + const std::vector& inputs, const std::vector>>& output); + + ArrayVectorPtr toArrayVector( + const std::vector>>& vector); }; +ArrayVectorPtr SplitTest::toArrayVector( + const std::vector>>& vector) { + // Creating vectors for output string vectors + auto sizeAt = [&vector](vector_size_t row) { + return vector[row] ? vector[row]->size() : 0; + }; + auto valueAt = [&vector](vector_size_t row, vector_size_t idx) { + return vector[row] ? StringView(vector[row]->at(idx)) : StringView(""); + }; + auto nullAt = [&vector](vector_size_t row) { + return !vector[row].has_value(); + }; + return makeArrayVector(vector.size(), sizeAt, valueAt, nullAt); +} + void SplitTest::testSplit( const std::vector>& input, std::optional pattern, - const std::vector>>& output) { + const std::vector>>& output, + int32_t limit) { auto valueAt = [&input](vector_size_t row) { return input[row] ?
StringView(*input[row]) : StringView(); }; @@ -50,28 +74,31 @@ void SplitTest::testSplit( std::string patternString = pattern.has_value() ? std::string(", '") + pattern.value() + "'" : ", ''"; + const std::string limitString = ", " + std::to_string(limit); std::string expressionString = - std::string("split(c0") + patternString + ")"; + std::string("split(c0") + patternString + limitString + ")"; return evaluate(expressionString, rowVector); }(); - // Creating vectors for output string vectors - auto sizeAtOutput = [&output](vector_size_t row) { - return output[row] ? output[row]->size() : 0; - }; - auto valueAtOutput = [&output](vector_size_t row, vector_size_t idx) { - return output[row] ? StringView(output[row]->at(idx)) : StringView(""); - }; - auto nullAtOutput = [&output](vector_size_t row) { - return !output[row].has_value(); - }; - auto expectedResult = makeArrayVector( - output.size(), sizeAtOutput, valueAtOutput, nullAtOutput); + const auto expectedResult = toArrayVector(output); // Checking the results assertEqualVectors(expectedResult, result); } +void SplitTest::testSplitEncodings( + const std::vector& inputs, + const std::vector>>& output) { + const auto expected = toArrayVector(output); + std::vector inputExprs = { + std::make_shared(inputs[0]->type(), "c0"), + std::make_shared(inputs[1]->type(), "c1"), + std::make_shared(inputs[2]->type(), "c2")}; + const auto expr = std::make_shared( + expected->type(), std::move(inputExprs), "split"); + testEncodings(expr, inputs, expected); +} + TEST_F(SplitTest, reallocationAndCornerCases) { testSplit( {"boo:and:foo", "abcfd", "abcfd:", "", ":ab::cfd::::"}, @@ -114,11 +141,25 @@ TEST_F(SplitTest, zeroLengthPattern) { {{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")"}}}); } -TEST_F(SplitTest, pattern) { - testSplit( - {"oneAtwoBthreeC", "oneAtwoBthreeCfourD"}, - "[ABC]", - {{{"one", "two", "three", ""}}, {{"one", "two", "three", "fourD"}}}); +TEST_F(SplitTest, encodings) { + const std::vector inputs =
{ + makeFlatVector( + {"oneAtwoBthreeC", + "a chrisr:9000 here", + "hello", + "1001 nights", + "morning"}), + makeFlatVector( + {"[ABC]", "((\\w+):([0-9]+))", "e.*o", "(\\d+)", "(mo)|ni"}), + makeFlatVector({-1, -1, 0, -1, -2}), + }; + const std::vector>> expected = { + {{"one", "two", "three", ""}}, + {{"a ", " here"}}, + {{"h", ""}}, + {{"", "nights"}}, + {{"", "r", "ng"}}}; + testSplitEncodings(inputs, expected); } } // namespace