From 2768da2076a4869e17d0a672545fce51b4d38be4 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Wed, 21 Feb 2024 13:41:56 +0800 Subject: [PATCH] Fix split function --- velox/docs/functions/spark/string.rst | 10 +- velox/functions/sparksql/SplitFunctions.cpp | 246 +++++++++++++----- .../sparksql/tests/SplitFunctionsTest.cpp | 143 ++++++++-- 3 files changed, 299 insertions(+), 100 deletions(-) diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 489e8f89238ba..20ac273ee89c5 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -178,9 +178,15 @@ Unless specified otherwise, all functions return NULL if at least one of the arg SELECT rtrim('kr', 'spark'); -- "spa" -.. spark:function:: split(string, delimiter) -> array(string) +.. spark:function:: split(string, delimiter, limit) -> array(string) - Splits ``string`` on ``delimiter`` and returns an array. :: + Splits ``string`` around occurrences that match ``delimiter`` and returns an array + with a length of at most ``limit``. ``delimiter`` is a string representing a regular + expression. ``limit`` is an integer which controls the number of times the regex is + applied. When ``limit`` > 0, the resulting array's length will not be more than + ``limit``, and the resulting array's last entry will contain all input beyond the + last matched regex. When ``limit`` <= 0, ``regex`` will be applied as many times as + possible, and the resulting array can be of any size. :: SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""] SELECT split('one', ''); -- ["o", "n", "e", ""] diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index 4d092e6928373..3f9f0ac3cd411 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "velox/expression/VectorFunction.h" @@ -22,16 +23,9 @@ namespace facebook::velox::functions::sparksql { namespace { -/// This class only implements the basic split version in which the pattern is a -/// single character -class SplitCharacter final : public exec::VectorFunction { +class Split final : public exec::VectorFunction { public: - explicit SplitCharacter(const char pattern) : pattern_{pattern} { - static constexpr std::string_view kRegexChars = ".$|()[{^?*+\\"; - VELOX_CHECK( - kRegexChars.find(pattern) == std::string::npos, - "This version of split supports single-length non-regex patterns"); - } + Split() {} void apply( const SelectivityVector& rows, @@ -45,23 +39,66 @@ class SplitCharacter final : public exec::VectorFunction { exec::VectorWriter> resultWriter; resultWriter.init(*result->as()); - rows.applyToSelected([&](vector_size_t row) { - resultWriter.setOffset(row); - auto& arrayWriter = resultWriter.current(); - - const StringView& current = input->valueAt(row); - const char* pos = current.begin(); - const char* end = pos + current.size(); - const char* delim; - do { - delim = std::find(pos, end, pattern_); - arrayWriter.add_item().setNoCopy(StringView(pos, delim - pos)); - pos = delim + 1; // Skip past delim. - } while (delim != end); - - resultWriter.commit(); - }); - + // Fast path for pattern and limit being constant. + if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) { + // Adds brackets to the input pattern for sub-pattern extraction. + const auto pattern = + args[1]->asUnchecked>()->valueAt(0); + const auto limit = + args[2]->asUnchecked>()->valueAt(0); + if (pattern.size() == 0) { + if (limit > 0) { + rows.applyToSelected([&](vector_size_t row) { + splitEmptyPattern( + input->valueAt(row), row, resultWriter, limit); + }); + } else { + rows.applyToSelected([&](vector_size_t row) { + splitEmptyPattern( + input->valueAt(row), row, resultWriter); + }); + } + } else { + const auto re = re2::RE2("(" + pattern.str() + ")"); + if (limit > 0) { + rows.applyToSelected([&](vector_size_t row) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter, limit); + }); + } else { + rows.applyToSelected([&](vector_size_t row) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter); + }); + } + } + } else { + exec::LocalDecodedVector patterns(context, *args[1], rows); + exec::LocalDecodedVector limits(context, *args[2], rows); + + rows.applyToSelected([&](vector_size_t row) { + const auto pattern = patterns->valueAt(row); + const auto limit = limits->valueAt(row); + if (pattern.size() == 0) { + if (limit > 0) { + splitEmptyPattern( + input->valueAt(row), row, resultWriter, limit); + } else { + splitEmptyPattern( + input->valueAt(row), row, resultWriter); + } + } else { + const auto re = re2::RE2("(" + pattern.str() + ")"); + if (limit > 0) { + splitAndWrite( + input->valueAt(row), re, row, resultWriter, limit); + } else { + splitAndWrite( + input->valueAt(row), re, row, resultWriter); + } + } + }); + } resultWriter.finish(); // Reference the input StringBuffers since we did not deep copy above. @@ -72,57 +109,125 @@ class SplitCharacter final : public exec::VectorFunction { } private: - const char pattern_; -}; - -/// This class will be updated in the future as we support more variants of -/// split -class Split final : public exec::VectorFunction { - public: - Split() {} + // When pattern is empty, split each character. + template + void splitEmptyPattern( + const StringView current, + vector_size_t row, + exec::VectorWriter>& resultWriter, + uint32_t limit = 0) const { + resultWriter.setOffset(row); + auto& arrayWriter = resultWriter.current(); + if (current.size() == 0) { + arrayWriter.add_item().setNoCopy(StringView()); + resultWriter.commit(); + return; + } + + const char* const begin = current.begin(); + const char* const end = current.end(); + const char* pos = begin; + if constexpr (limited) { + VELOX_DCHECK_GT(limit, 0); + while (pos != end && pos - begin < limit - 1) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; + if (pos == end) { + arrayWriter.add_item().setNoCopy(StringView()); + } + } + if (pos < end) { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + } + } else { + while (pos != end) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; + if (pos == end) { + arrayWriter.add_item().setNoCopy(StringView()); + } + } + } + resultWriter.commit(); + } - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /* outputType */, - exec::EvalCtx& context, - VectorPtr& result) const override { - auto delimiterVector = args[1]->as>(); - VELOX_CHECK( - delimiterVector, "Split function supports only constant delimiter"); - auto patternString = args[1]->as>()->valueAt(0); - VELOX_CHECK_EQ( - patternString.size(), - 1, - "split only supports only single-character pattern"); - char pattern = patternString.data()[0]; - SplitCharacter splitCharacter(pattern); - splitCharacter.apply(rows, args, nullptr, context, result); + // Split with a non-empty pattern. + template + void splitAndWrite( + const StringView current, + const re2::RE2& re, + vector_size_t row, + exec::VectorWriter>& resultWriter, + uint32_t limit = 0) const { + resultWriter.setOffset(row); + auto& arrayWriter = resultWriter.current(); + if (current.size() == 0) { + arrayWriter.add_item().setNoCopy(StringView()); + resultWriter.commit(); + return; + } + + const char* pos = current.begin(); + const char* const end = current.end(); + if constexpr (limited) { + VELOX_DCHECK_GT(limit, 0); + uint32_t numPieces = 0; + while (pos != end && numPieces < limit - 1) { + if (re2::StringPiece piece; re2::RE2::PartialMatch( + re2::StringPiece(pos, end - pos), re, &piece)) { + arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); + numPieces += 1; + if (piece.end() == end) { + // When the found delimiter is at the end of input string, keeps + // one empty piece of string. + arrayWriter.add_item().setNoCopy(StringView()); + } + pos = piece.end(); + } else { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + pos = end; + } + } + if (pos < end) { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + } + } else { + while (pos != end) { + if (re2::StringPiece piece; re2::RE2::PartialMatch( + re2::StringPiece(pos, end - pos), re, &piece)) { + arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos)); + if (piece.end() == end) { + arrayWriter.add_item().setNoCopy(StringView()); + } + pos = piece.end(); + } else { + arrayWriter.add_item().setNoCopy(StringView(pos, end - pos)); + pos = end; + } + } + } + resultWriter.commit(); } }; -/// The function returns specialized version of split based on the constant -/// inputs. -/// \param inputArgs the inputs types (VARCHAR, VARCHAR, int64) and constant -/// values (if provided). +/// Returns split function. +/// @param inputArgs the inputs types (VARCHAR, VARCHAR, int32). std::shared_ptr createSplit( const std::string& /*name*/, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { - BaseVector* constantPattern = inputArgs[1].constantValue.get(); - - if (inputArgs.size() > 3 || inputArgs[0].type->isVarchar() || - inputArgs[1].type->isVarchar() || (constantPattern == nullptr)) { - return std::make_shared(); - } - auto pattern = constantPattern->as>()->valueAt(0); - if (pattern.size() != 1) { - return std::make_shared(); - } - char charPattern = pattern.data()[0]; - // TODO: Add support for zero-length pattern, 2-character pattern - // TODO: add support for general regex pattern using R2 - return std::make_shared(charPattern); + VELOX_USER_CHECK_EQ( + inputArgs.size(), 3, "Three arguments are required for split function."); + VELOX_USER_CHECK( + inputArgs[0].type->isVarchar(), + "The first argument should be of varchar type."); + VELOX_USER_CHECK( + inputArgs[1].type->isVarchar(), + "The second argument should be of varchar type."); + VELOX_USER_CHECK( + inputArgs[2].type->kind() == TypeKind::INTEGER, + "The third argument should be of integer type."); + return std::make_shared(); } std::vector> signatures() { @@ -130,7 +235,8 @@ std::vector> signatures() { return {exec::FunctionSignatureBuilder() .returnType("array(varchar)") .argumentType("varchar") - .constantArgumentType("varchar") + .argumentType("varchar") + .argumentType("integer") .build()}; } diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index 8928849a44ce2..c8d49c0163ddf 100644 --- a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -24,16 +24,40 @@ namespace { class SplitTest : public SparkFunctionBaseTest { protected: - void testSplitCharacter( + void testSplit( const std::vector>& input, - std::optional pattern, + std::optional pattern, + const std::vector>>& output, + int32_t limit = -1); + + void testSplitEncodings( + const std::vector& inputs, const std::vector>>& output); + + ArrayVectorPtr toArrayVector( + const std::vector>>& vector); }; -void SplitTest::testSplitCharacter( +ArrayVectorPtr SplitTest::toArrayVector( + const std::vector>>& vector) { + // Creating vectors for output string vectors + auto sizeAt = [&vector](vector_size_t row) { + return vector[row] ? vector[row]->size() : 0; + }; + auto valueAt = [&vector](vector_size_t row, vector_size_t idx) { + return vector[row] ? StringView(vector[row]->at(idx)) : StringView(""); + }; + auto nullAt = [&vector](vector_size_t row) { + return !vector[row].has_value(); + }; + return makeArrayVector(vector.size(), sizeAt, valueAt, nullAt); +} + +void SplitTest::testSplit( const std::vector>& input, - std::optional pattern, - const std::vector>>& output) { + std::optional pattern, + const std::vector>>& output, + int32_t limit) { auto valueAt = [&input](vector_size_t row) { return input[row] ? StringView(*input[row]) : StringView(); }; @@ -47,34 +71,39 @@ void SplitTest::testSplitCharacter( auto rowVector = makeRowVector({inputString}); // Evaluating the function for each input and seed - std::string patternString = - pattern.has_value() ? std::string(", '") + pattern.value() + "'" : ""; + std::string patternString = pattern.has_value() + ? std::string(", '") + pattern.value() + "'" + : ", ''"; + const std::string limitString = + ", '" + std::to_string(limit) + "'::INTEGER"; std::string expressionString = - std::string("split(c0") + patternString + ")"; + std::string("split(c0") + patternString + limitString + ")"; return evaluate(expressionString, rowVector); }(); - // Creating vectors for output string vectors - auto sizeAtOutput = [&output](vector_size_t row) { - return output[row] ? output[row]->size() : 0; - }; - auto valueAtOutput = [&output](vector_size_t row, vector_size_t idx) { - return output[row] ? StringView(output[row]->at(idx)) : StringView(""); - }; - auto nullAtOutput = [&output](vector_size_t row) { - return !output[row].has_value(); - }; - auto expectedResult = makeArrayVector( - output.size(), sizeAtOutput, valueAtOutput, nullAtOutput); + const auto expectedResult = toArrayVector(output); // Checking the results assertEqualVectors(expectedResult, result); } +void SplitTest::testSplitEncodings( + const std::vector& inputs, + const std::vector>>& output) { + const auto expected = toArrayVector(output); + std::vector inputExprs = { + std::make_shared(inputs[0]->type(), "c0"), + std::make_shared(inputs[1]->type(), "c1"), + std::make_shared(inputs[2]->type(), "c2")}; + const auto expr = std::make_shared( + expected->type(), std::move(inputExprs), "split"); + testEncodings(expr, inputs, expected); +} + TEST_F(SplitTest, reallocationAndCornerCases) { - testSplitCharacter( + testSplit( {"boo:and:foo", "abcfd", "abcfd:", "", ":ab::cfd::::"}, - ':', + ":", {{{"boo", "and", "foo"}}, {{"abcfd"}}, {{"abcfd", ""}}, @@ -83,9 +112,9 @@ TEST_F(SplitTest, reallocationAndCornerCases) { } TEST_F(SplitTest, nulls) { - testSplitCharacter( + testSplit( {std::nullopt, "abcfd", "abcfd:", std::nullopt, ":ab::cfd::::"}, - ':', + ":", {{std::nullopt}, {{"abcfd"}}, {{"abcfd", ""}}, @@ -94,16 +123,74 @@ TEST_F(SplitTest, nulls) { } TEST_F(SplitTest, defaultArguments) { - testSplitCharacter( - {"boo:and:foo", "abcfd"}, ':', {{{"boo", "and", "foo"}}, {{"abcfd"}}}); + testSplit( + {"boo:and:foo", "abcfd"}, ":", {{{"boo", "and", "foo"}}, {{"abcfd"}}}); } TEST_F(SplitTest, longStrings) { - testSplitCharacter( + testSplit( {"abcdefghijklkmnopqrstuvwxyz"}, - ',', + ",", {{{"abcdefghijklkmnopqrstuvwxyz"}}}); } +TEST_F(SplitTest, zeroLengthPattern) { + testSplit( + {"abcdefg", "abc:+%/n?(^)", ""}, + std::nullopt, + {{{"a", "b", "c", "d", "e", "f", "g", ""}}, + {{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")", ""}}, + {{""}}}); + testSplit( + {"abcdefg", "abc:+%/n?(^)", ""}, + std::nullopt, + {{{"a", "b", "cdefg"}}, {{"a", "b", "c:+%/n?(^)"}}, {{""}}}, + 3); + testSplit( + {"abcdefg", "abc:+%/n?(^)", ""}, + std::nullopt, + {{{"a", "b", "c", "d", "e", "f", "g", ""}}, + {{"a", "b", "c", ":", "+", "%", "/", "n", "?", "(", "^", ")", ""}}, + {{""}}}, + 20); +} + +TEST_F(SplitTest, encodings) { + auto strings = makeFlatVector( + {"abcdef", + "oneAtwoBthreeC", + "a chrisr:9000 here", + "hello", + "1001 nights", + "morning", + "", + ""}); + auto patterns = makeFlatVector( + {"", "[ABC]", "((\\w+):([0-9]+))", "e.*o", "(\\d+)", "(mo)|ni", ":", ""}); + auto limits = makeFlatVector({0, -1, -1, 0, -1, -2, -1, -4}); + std::vector>> expected = { + {{"a", "b", "c", "d", "e", "f", ""}}, + {{"one", "two", "three", ""}}, + {{"a ", " here"}}, + {{"h", ""}}, + {{"", " nights"}}, + {{"", "r", "ng"}}, + {{""}}, + {{""}}}; + testSplitEncodings({strings, patterns, limits}, expected); + + limits = makeFlatVector({3, 3, 3, 1, 1, 2, 1, 1}); + expected = { + {{"a", "b", "cdef"}}, + {{"one", "two", "threeC"}}, + {{"a ", " here"}}, + {{"hello"}}, + {{"1001 nights"}}, + {{"", "rning"}}, + {{""}}, + {{""}}}; + testSplitEncodings({strings, patterns, limits}, expected); +} + } // namespace } // namespace facebook::velox::functions::sparksql::test