diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 20ac273ee89c5..354fc0cec4654 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -178,19 +178,22 @@ Unless specified otherwise, all functions return NULL if at least one of the arg SELECT rtrim('kr', 'spark'); -- "spa" -.. spark:function:: split(string, delimiter, limit) -> array(string) +.. spark:function:: split(string, delimiter[, limit]) -> array(string) Splits ``string`` around occurrences that match ``delimiter`` and returns an array with a length of at most ``limit``. ``delimiter`` is a string representing a regular expression. ``limit`` is an integer which controls the number of times the regex is - applied. When ``limit`` > 0, the resulting array's length will not be more than - ``limit``, and the resulting array's last entry will contain all input beyond the - last matched regex. When ``limit`` <= 0, ``regex`` will be applied as many times as - possible, and the resulting array can be of any size. :: + applied. By default, ``limit`` is -1. When ``limit`` > 0, the resulting array's + length will not be more than ``limit``, and the resulting array's last entry will + contain all input beyond the last matched regex. When ``limit`` <= 0, ``delimiter`` will + be applied as many times as possible, and the resulting array can be of any size. With an empty ``delimiter`` and ``limit`` > 0, only the first ``limit`` single-character entries are returned and the rest of the input is dropped. :: SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""] + SELECT split('oneAtwoBthreeC', '[ABC]', 2); -- ["one","twoBthreeC"] SELECT split('one', ''); -- ["o", "n", "e", ""] SELECT split('one', '1'); -- ["one"] + SELECT split('abcd', ''); -- ["a", "b", "c", "d"] + SELECT split('abcd', '', 3); -- ["a", "b", "c"] ..
spark:function:: split(string, delimiter, limit) -> array(string) :noindex: diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index bd439fd7cb09e..b8ec5c3f80400 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -36,39 +36,37 @@ class Split final : public exec::VectorFunction { BaseVector::ensureWritable(rows, ARRAY(VARCHAR()), context.pool(), result); exec::VectorWriter> resultWriter; resultWriter.init(*result->as()); + int32_t limit = -1; - // Fast path for pattern and limit being constant. - if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) { + // Fast path for pattern and limit being constants. + if (args[1]->isConstantEncoding() && + (args.size() == 2 || args[2]->isConstantEncoding())) { // Adds brackets to the input pattern for sub-pattern extraction. const auto pattern = args[1]->asUnchecked>()->valueAt(0); - const auto limit = - args[2]->asUnchecked>()->valueAt(0); + if (args.size() > 2) { + limit = args[2]->asUnchecked>()->valueAt(0); + } + const auto positiveLimit = + limit > 0 ? 
limit : std::numeric_limits::max(); if (pattern.size() == 0) { - if (limit > 0) { - rows.applyToSelected([&](vector_size_t row) { - splitEmptyPattern( - input->valueAt(row), row, resultWriter, limit); - }); - } else { - rows.applyToSelected([&](vector_size_t row) { - splitEmptyPattern( - input->valueAt(row), row, resultWriter); - }); - } + rows.applyToSelected([&](vector_size_t row) { + splitEmptyPattern( + input->valueAt(row), + row, + resultWriter, + positiveLimit); + }); } else { const auto re = re2::RE2("(" + pattern.str() + ")"); - if (limit > 0) { - rows.applyToSelected([&](vector_size_t row) { - splitAndWrite( - input->valueAt(row), re, row, resultWriter, limit); - }); - } else { - rows.applyToSelected([&](vector_size_t row) { - splitAndWrite( - input->valueAt(row), re, row, resultWriter); - }); - } + rows.applyToSelected([&](vector_size_t row) { + splitAndWrite( + input->valueAt(row), + re, + row, + resultWriter, + positiveLimit); + }); } } else { exec::LocalDecodedVector patterns(context, *args[1], rows); @@ -76,24 +74,22 @@ class Split final : public exec::VectorFunction { rows.applyToSelected([&](vector_size_t row) { const auto pattern = patterns->valueAt(row); - const auto limit = limits->valueAt(row); + limit = limits->valueAt(row); + const auto positiveLimit = + limit > 0 ? 
limit : std::numeric_limits::max(); if (pattern.size() == 0) { - if (limit > 0) { - splitEmptyPattern( - input->valueAt(row), row, resultWriter, limit); - } else { - splitEmptyPattern( - input->valueAt(row), row, resultWriter); - } + splitEmptyPattern( + input->valueAt(row), + row, + resultWriter, + positiveLimit); } else { - const auto re = re2::RE2("(" + pattern.str() + ")"); - if (limit > 0) { - splitAndWrite( - input->valueAt(row), re, row, resultWriter, limit); - } else { - splitAndWrite( - input->valueAt(row), re, row, resultWriter); - } + splitAndWrite( + input->valueAt(row), + re2::RE2("(" + pattern.str() + ")"), + row, + resultWriter, + positiveLimit); } }); } @@ -108,12 +104,11 @@ class Split final : public exec::VectorFunction { private: // When pattern is empty, split each character. - template void splitEmptyPattern( const StringView current, vector_size_t row, exec::VectorWriter>& resultWriter, - uint32_t limit = 0) const { + uint32_t limit) const { resultWriter.setOffset(row); auto& arrayWriter = resultWriter.current(); if (current.size() == 0) { @@ -125,32 +120,20 @@ class Split final : public exec::VectorFunction { const char* const begin = current.begin(); const char* const end = current.end(); const char* pos = begin; - if constexpr (limited) { - VELOX_DCHECK_GT(limit, 0); - while (pos != end && pos - begin < limit - 1) { - arrayWriter.add_item().setNoCopy(StringView(pos, 1)); - pos += 1; - } - if (pos < end) { - arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); - } - } else { - while (pos != end) { - arrayWriter.add_item().setNoCopy(StringView(pos, 1)); - pos += 1; - } + while (pos < end && pos < limit + begin) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; } resultWriter.commit(); } // Split with a non-empty pattern. 
- template void splitAndWrite( const StringView current, const re2::RE2& re, vector_size_t row, exec::VectorWriter>& resultWriter, - uint32_t limit = 0) const { + uint32_t limit) const { resultWriter.setOffset(row); auto& arrayWriter = resultWriter.current(); if (current.size() == 0) { @@ -161,42 +144,26 @@ class Split final : public exec::VectorFunction { const char* pos = current.begin(); const char* const end = current.end(); - if constexpr (limited) { - VELOX_DCHECK_GT(limit, 0); - uint32_t numPieces = 0; - while (pos != end && numPieces < limit - 1) { - if (re2::StringPiece piece; re2::RE2::PartialMatch( - re2::StringPiece(pos, end - pos), re, &piece)) { - arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); - numPieces += 1; - if (piece.end() == end) { - // When the found delimiter is at the end of input string, keeps - // one empty piece of string. - arrayWriter.add_item().setNoCopy(StringView()); - } - pos = piece.end(); - } else { - arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); - pos = end; + VELOX_DCHECK_GT(limit, 0); + uint32_t numPieces = 0; + while (pos < end && numPieces < limit - 1) { + if (re2::StringPiece piece; re2::RE2::PartialMatch( + re2::StringPiece(pos, end - pos), re, &piece)) { + arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); + numPieces += 1; + if (piece.end() == end) { + // When the found delimiter is at the end of input string, keeps + // one empty piece of string. 
+ arrayWriter.add_item().setNoCopy(StringView()); } - } - if (pos < end) { + pos = piece.end(); + } else { arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + pos = end; } - } else { - while (pos != end) { - if (re2::StringPiece piece; re2::RE2::PartialMatch( - re2::StringPiece(pos, end - pos), re, &piece)) { - arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); - if (piece.end() == end) { - arrayWriter.add_item().setNoCopy(StringView()); - } - pos = piece.end(); - } else { - arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); - pos = end; - } - } + } + if (pos < end) { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); } resultWriter.commit(); } @@ -208,28 +175,37 @@ std::shared_ptr createSplit( const std::string& /*name*/, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { - VELOX_USER_CHECK_EQ( - inputArgs.size(), 3, "Three arguments are required for split function."); + VELOX_USER_CHECK( + inputArgs.size() == 2 || inputArgs.size() == 3, + "Two or three arguments are required for split function."); VELOX_USER_CHECK( inputArgs[0].type->isVarchar(), "The first argument should be of varchar type."); VELOX_USER_CHECK( inputArgs[1].type->isVarchar(), "The second argument should be of varchar type."); - VELOX_USER_CHECK( - inputArgs[2].type->kind() == TypeKind::INTEGER, - "The third argument should be of integer type."); + if (inputArgs.size() > 2) { + VELOX_USER_CHECK( + inputArgs[2].type->kind() == TypeKind::INTEGER, + "The third argument should be of integer type."); + } return std::make_shared(); } std::vector> signatures() { // varchar, varchar -> array(varchar) - return {exec::FunctionSignatureBuilder() - .returnType("array(varchar)") - .argumentType("varchar") - .argumentType("varchar") - .argumentType("integer") - .build()}; + return { + exec::FunctionSignatureBuilder() + .returnType("array(varchar)") + .argumentType("varchar") + .argumentType("varchar") + .build(), + 
exec::FunctionSignatureBuilder() + .returnType("array(varchar)") + .argumentType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build()}; } } // namespace diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index a2abb8cfbcddb..2450c5845018c 100644 --- a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -28,7 +28,7 @@ class SplitTest : public SparkFunctionBaseTest { const std::vector>& input, std::optional pattern, const std::vector>>& output, - int32_t limit = -1); + std::optional limit = std::nullopt); void testSplitEncodings( const std::vector& inputs, @@ -57,7 +57,7 @@ void SplitTest::testSplit( const std::vector>& input, std::optional pattern, const std::vector>>& output, - int32_t limit) { + std::optional limit) { auto valueAt = [&input](vector_size_t row) { return input[row] ? StringView(*input[row]) : StringView(); }; @@ -74,8 +74,9 @@ void SplitTest::testSplit( std::string patternString = pattern.has_value() ? std::string(", '") + pattern.value() + "'" : ", ''"; - const std::string limitString = - ", '" + std::to_string(limit) + "'::INTEGER"; + const std::string limitString = limit.has_value() + ? ", '" + std::to_string(limit.value()) + "'::INTEGER" + : ""; std::string expressionString = std::string("split(c0") + patternString + limitString + ")"; return evaluate(expressionString, rowVector); @@ -143,10 +144,13 @@ TEST_F(SplitTest, zeroLengthPattern) { {{{"a", "b", "c", "d", "e", "f", "g"}}, {{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")"}}, {{""}}}); + + // The result does not include remaining string when limit is smaller than the + // string size. 
testSplit( - {"abcdefg", "abc:+%/n?(^)", ""}, + {"abcdefg", "ab:c+%/n?(^)", ""}, std::nullopt, - {{{"a", "b", "cdefg"}}, {{"a", "b", "c:+%/n?(^)"}}, {{""}}}, + {{{"a", "b", "c"}}, {{"a", "b", ":"}}, {{""}}}, 3); testSplit( {"abcdefg", "abc:+%/n?(^)", ""}, @@ -203,7 +207,7 @@ TEST_F(SplitTest, encodings) { limits = makeFlatVector({3, 3, 2, 1, 5, 2, 1, 1, 2, 2}); expected = { - {{"a", "b", "cdef"}}, + {{"a", "b", "c"}}, {{"one", "two", "threeC"}}, {{"aa", "bb3cc"}}, {{"hello"}},