Skip to content

Commit

Permalink
Fix split function
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 22, 2024
1 parent ffd136f commit 6cc0b4f
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 100 deletions.
10 changes: 8 additions & 2 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,15 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT rtrim('kr', 'spark'); -- "spa"

.. spark:function:: split(string, delimiter) -> array(string)
.. spark:function:: split(string, delimiter, limit) -> array(string)
Splits ``string`` on ``delimiter`` and returns an array. ::
Splits ``string`` around occurrences that match ``delimiter`` and returns an array
with a length of at most ``limit``. ``delimiter`` is a string representing a regular
expression. ``limit`` is an integer which controls the number of times the regex is
applied. When ``limit`` > 0, the resulting array's length will not be more than
``limit``, and the resulting array's last entry will contain all input beyond the
last matched regex. When ``limit`` <= 0, ``regex`` will be applied as many times as
possible, and the resulting array can be of any size. ::

SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""]
SELECT split('one', ''); -- ["o", "n", "e", ""]
Expand Down
246 changes: 176 additions & 70 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <re2/re2.h>
#include <utility>

#include "velox/expression/VectorFunction.h"
Expand All @@ -22,16 +23,9 @@
namespace facebook::velox::functions::sparksql {
namespace {

/// This class only implements the basic split version in which the pattern is a
/// single character
class SplitCharacter final : public exec::VectorFunction {
class Split final : public exec::VectorFunction {
public:
explicit SplitCharacter(const char pattern) : pattern_{pattern} {
static constexpr std::string_view kRegexChars = ".$|()[{^?*+\\";
VELOX_CHECK(
kRegexChars.find(pattern) == std::string::npos,
"This version of split supports single-length non-regex patterns");
}
Split() {}

void apply(
const SelectivityVector& rows,
Expand All @@ -45,23 +39,66 @@ class SplitCharacter final : public exec::VectorFunction {
exec::VectorWriter<Array<Varchar>> resultWriter;
resultWriter.init(*result->as<ArrayVector>());

rows.applyToSelected([&](vector_size_t row) {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();

const StringView& current = input->valueAt<StringView>(row);
const char* pos = current.begin();
const char* end = pos + current.size();
const char* delim;
do {
delim = std::find(pos, end, pattern_);
arrayWriter.add_item().setNoCopy(StringView(pos, delim - pos));
pos = delim + 1; // Skip past delim.
} while (delim != end);

resultWriter.commit();
});

// Fast path for pattern and limit being constant.
if (args[1]->isConstantEncoding() && args[2]->isConstantEncoding()) {
// Adds brackets to the input pattern for sub-pattern extraction.
const auto pattern =
args[1]->asUnchecked<ConstantVector<StringView>>()->valueAt(0);
const auto limit =
args[2]->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
if (pattern.size() == 0) {
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
});
}
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
if (limit > 0) {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
});
} else {
rows.applyToSelected([&](vector_size_t row) {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
});
}
}
} else {
exec::LocalDecodedVector patterns(context, *args[1], rows);
exec::LocalDecodedVector limits(context, *args[2], rows);

rows.applyToSelected([&](vector_size_t row) {
const auto pattern = patterns->valueAt<StringView>(row);
const auto limit = limits->valueAt<int32_t>(row);
if (pattern.size() == 0) {
if (limit > 0) {
splitEmptyPattern<true>(
input->valueAt<StringView>(row), row, resultWriter, limit);
} else {
splitEmptyPattern<false>(
input->valueAt<StringView>(row), row, resultWriter);
}
} else {
const auto re = re2::RE2("(" + pattern.str() + ")");
if (limit > 0) {
splitAndWrite<true>(
input->valueAt<StringView>(row), re, row, resultWriter, limit);
} else {
splitAndWrite<false>(
input->valueAt<StringView>(row), re, row, resultWriter);
}
}
});
}
resultWriter.finish();

// Reference the input StringBuffers since we did not deep copy above.
Expand All @@ -72,65 +109,134 @@ class SplitCharacter final : public exec::VectorFunction {
}

private:
const char pattern_;
};

/// This class will be updated in the future as we support more variants of
/// split
class Split final : public exec::VectorFunction {
public:
Split() {}
// When pattern is empty, split each character.
template <bool limited>
void splitEmptyPattern(
const StringView current,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter,
uint32_t limit = 0) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
if (current.size() == 0) {
arrayWriter.add_item().setNoCopy(StringView());
resultWriter.commit();
return;
}

const char* const begin = current.begin();
const char* const end = current.end();
const char* pos = begin;
if constexpr (limited) {
VELOX_DCHECK_GT(limit, 0);
while (pos != end && pos - begin < limit - 1) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
if (pos == end) {
arrayWriter.add_item().setNoCopy(StringView());
}
}
if (pos < end) {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
}
} else {
while (pos != end) {
arrayWriter.add_item().setNoCopy(StringView(pos, 1));
pos += 1;
if (pos == end) {
arrayWriter.add_item().setNoCopy(StringView());
}
}
}
resultWriter.commit();
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
auto delimiterVector = args[1]->as<ConstantVector<StringView>>();
VELOX_CHECK(
delimiterVector, "Split function supports only constant delimiter");
auto patternString = args[1]->as<ConstantVector<StringView>>()->valueAt(0);
VELOX_CHECK_EQ(
patternString.size(),
1,
"split only supports only single-character pattern");
char pattern = patternString.data()[0];
SplitCharacter splitCharacter(pattern);
splitCharacter.apply(rows, args, nullptr, context, result);
// Split with a non-empty pattern.
template <bool limited>
void splitAndWrite(
const StringView current,
const re2::RE2& re,
vector_size_t row,
exec::VectorWriter<Array<Varchar>>& resultWriter,
uint32_t limit = 0) const {
resultWriter.setOffset(row);
auto& arrayWriter = resultWriter.current();
if (current.size() == 0) {
arrayWriter.add_item().setNoCopy(StringView());
resultWriter.commit();
return;
}

const char* pos = current.begin();
const char* const end = current.end();
if constexpr (limited) {
VELOX_DCHECK_GT(limit, 0);
uint32_t numPieces = 0;
while (pos != end && numPieces < limit - 1) {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
numPieces += 1;
if (piece.end() == end) {
// When the found delimiter is at the end of input string, keeps
// one empty piece of string.
arrayWriter.add_item().setNoCopy(StringView());
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
}
}
if (pos < end) {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
}
} else {
while (pos != end) {
if (re2::StringPiece piece; re2::RE2::PartialMatch(
re2::StringPiece(pos, end - pos), re, &piece)) {
arrayWriter.add_item().setNoCopy(StringView(pos, piece.data() - pos));
if (piece.end() == end) {
arrayWriter.add_item().setNoCopy(StringView());
}
pos = piece.end();
} else {
arrayWriter.add_item().setNoCopy(StringView(pos, end - pos));
pos = end;
}
}
}
resultWriter.commit();
}
};

/// The function returns specialized version of split based on the constant
/// inputs.
/// \param inputArgs the inputs types (VARCHAR, VARCHAR, int64) and constant
/// values (if provided).
/// Returns split function.
/// @param inputArgs the inputs types (VARCHAR, VARCHAR, int32).
std::shared_ptr<exec::VectorFunction> createSplit(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
BaseVector* constantPattern = inputArgs[1].constantValue.get();

if (inputArgs.size() > 3 || inputArgs[0].type->isVarchar() ||
inputArgs[1].type->isVarchar() || (constantPattern == nullptr)) {
return std::make_shared<Split>();
}
auto pattern = constantPattern->as<ConstantVector<StringView>>()->valueAt(0);
if (pattern.size() != 1) {
return std::make_shared<Split>();
}
char charPattern = pattern.data()[0];
// TODO: Add support for zero-length pattern, 2-character pattern
// TODO: add support for general regex pattern using R2
return std::make_shared<SplitCharacter>(charPattern);
VELOX_USER_CHECK_EQ(
inputArgs.size(), 3, "Three arguments are required for split function.");
VELOX_USER_CHECK(
inputArgs[0].type->isVarchar(),
"The first argument should be of varchar type.");
VELOX_USER_CHECK(
inputArgs[1].type->isVarchar(),
"The second argument should be of varchar type.");
VELOX_USER_CHECK(
inputArgs[2].type->kind() == TypeKind::INTEGER,
"The third argument should be of integer type.");
return std::make_shared<Split>();
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// varchar, varchar -> array(varchar)
return {exec::FunctionSignatureBuilder()
.returnType("array(varchar)")
.argumentType("varchar")
.constantArgumentType("varchar")
.argumentType("varchar")
.argumentType("integer")
.build()};
}

Expand Down
Loading

0 comments on commit 6cc0b4f

Please sign in to comment.