Skip to content

Commit

Permalink
Use simple function api
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Dec 5, 2023
1 parent f435675 commit 03e3e49
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 159 deletions.
9 changes: 8 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "velox/functions/sparksql/RegisterCompare.h"
#include "velox/functions/sparksql/Size.h"
#include "velox/functions/sparksql/String.h"
#include "velox/functions/sparksql/StringToMap.h"
#include "velox/functions/sparksql/UnscaledValueFunction.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -144,6 +145,13 @@ void registerFunctions(const std::string& prefix) {
int32_t,
int32_t>({prefix + "overlay"});

registerFunction<
sparksql::StringToMapFunction,
Map<Varchar, Varchar>,
Varchar,
Varchar,
Varchar>({prefix + "str_to_map"});

registerFunction<sparksql::LeftFunction, Varchar, Varchar, int32_t>(
{prefix + "left"});

Expand All @@ -153,7 +161,6 @@ void registerFunctions(const std::string& prefix) {
prefix + "length", lengthSignatures(), makeLength);
registerFunction<SubstringIndexFunction, Varchar, Varchar, Varchar, int32_t>(
{prefix + "substring_index"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix + "str_to_map");

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
117 changes: 0 additions & 117 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,128 +134,11 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> splitSignatures() {
.constantArgumentType("varchar")
.build()};
}

class StringToMap final : public exec::VectorFunction {
public:
StringToMap(char entryDelim, char keyValueDelim)
: entryDelim_(entryDelim), keyValueDelim_(keyValueDelim) {}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
exec::DecodedArgs decodedArgs(rows, args, context);
DecodedVector* inputString = decodedArgs.at(0);

BaseVector::ensureWritable(
rows, MAP(VARCHAR(), VARCHAR()), context.pool(), result);
exec::VectorWriter<Map<Varchar, Varchar>> resultWriter;
resultWriter.init(*result->as<MapVector>());

context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
folly::F14FastSet<StringView> keys;
resultWriter.setOffset(row);
auto& mapWriter = resultWriter.current();

const StringView& current = inputString->valueAt<StringView>(row);
const char* pos = current.begin();
const char* end = pos + current.size();
const char* entryEnd;
const char* keyEnd;
do {
entryEnd = std::find(pos, end, entryDelim_);
keyEnd = std::find(pos, entryEnd, keyValueDelim_);
const auto key = StringView(pos, keyEnd - pos);
VELOX_USER_CHECK(
keys.insert(key).second,
"Duplicated keys ('{}') are not allowed.",
key);
if (keyEnd == entryEnd) {
mapWriter.add_null().append(key);
} else {
auto [keyWriter, valueWriter] = mapWriter.add_item();
keyWriter.setNoCopy(key);
valueWriter.setNoCopy(StringView(keyEnd + 1, entryEnd - keyEnd - 1));
}
pos = entryEnd + 1; // Skip past delim.
} while (entryEnd != end);

resultWriter.commit();
});

resultWriter.finish();

// Reuses input buffer.
result->as<MapVector>()
->mapKeys()
->as<FlatVector<StringView>>()
->acquireSharedStringBuffers(inputString->base());
result->as<MapVector>()
->mapValues()
->as<FlatVector<StringView>>()
->acquireSharedStringBuffers(inputString->base());
}

private:
const char entryDelim_;
const char keyValueDelim_;
};

/// Currently only supports single-character entryDelim & keyValueDelim and
/// these two delimiters must be constant.
std::shared_ptr<exec::VectorFunction> createStringToMap(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
VELOX_USER_CHECK_EQ(
inputArgs.size(), 3, "Expects 3 arguments for StringToMap.");

auto getDelimiter =
[](exec::VectorFunctionArg inputArg) -> std::optional<char> {
BaseVector* constantVector = inputArg.constantValue.get();
VELOX_USER_CHECK_NOT_NULL(
constantVector,
"StringToMap requires constant entry/key-value delimiter.");
const auto constantStringView =
constantVector->as<ConstantVector<StringView>>();
if (constantStringView->isNullAt(0)) {
return std::nullopt;
}
VELOX_USER_CHECK(
constantStringView->valueAt(0).size() == 1,
"StringToMap only supports single-character entry/key-value delimiter.");
return constantStringView->valueAt(0).data()[0];
};
std::optional<char> entryDelim = getDelimiter(inputArgs[1]);
std::optional<char> keyValueDelim = getDelimiter(inputArgs[2]);
// As isDefaultNullBehavior = true.
if (!entryDelim.has_value() || !keyValueDelim.has_value()) {
return std::make_shared<exec::ApplyNeverCalled>();
}
return std::make_shared<StringToMap>(*entryDelim, *keyValueDelim);
}

// varchar, varchar, varchar -> map(varchar, varchar).
std::vector<std::shared_ptr<exec::FunctionSignature>> stringToMapSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("map(varchar, varchar)")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build()};
}

} // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_regexp_split,
splitSignatures(),
createSplit);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_str_to_map,
stringToMapSignatures(),
createStringToMap);
} // namespace facebook::velox::functions::sparksql
103 changes: 103 additions & 0 deletions velox/functions/sparksql/StringToMap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "folly/container/F14Set.h"
#include "velox/functions/Udf.h"

namespace facebook::velox::functions::sparksql {

template <typename TExecCtx>
struct StringToMapFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExecCtx);

// Results refer to strings in the first argument.
static constexpr int32_t reuse_strings_from_arg = 0;

void call(
out_type<Map<Varchar, Varchar>>& out,
const arg_type<Varchar>& input,
const arg_type<Varchar>& entryDelimiter,
const arg_type<Varchar>& keyValueDelimiter) {
VELOX_USER_CHECK(!entryDelimiter.empty(), "entryDelimiter is empty");
VELOX_USER_CHECK(!keyValueDelimiter.empty(), "keyValueDelimiter is empty");

callImpl(
out,
toStringView(input),
toStringView(entryDelimiter),
toStringView(keyValueDelimiter));
}

private:
static std::string_view toStringView(const arg_type<Varchar>& input) {
return std::string_view(input.data(), input.size());
}

void callImpl(
out_type<Map<Varchar, Varchar>>& out,
std::string_view input,
std::string_view entryDelimiter,
std::string_view keyValueDelimiter) const {
size_t pos = 0;

folly::F14FastSet<std::string_view> keys;

auto nextEntryPos = input.find(entryDelimiter, pos);
while (nextEntryPos != std::string::npos) {
processEntry(
out,
std::string_view(input.data() + pos, nextEntryPos - pos),
keyValueDelimiter,
keys);

pos = nextEntryPos + 1;
nextEntryPos = input.find(entryDelimiter, pos);
}

processEntry(
out,
std::string_view(input.data() + pos, input.size() - pos),
keyValueDelimiter,
keys);
}

void processEntry(
out_type<Map<Varchar, Varchar>>& out,
std::string_view entry,
std::string_view keyValueDelimiter,
folly::F14FastSet<std::string_view>& keys) const {
const auto delimiterPos = entry.find(keyValueDelimiter, 0);
// Not found key/value delimiter.
if (delimiterPos == std::string::npos) {
out.add_null().append(StringView(entry));
return;
}
const auto key = std::string_view(entry.data(), delimiterPos);
VELOX_USER_CHECK(
keys.insert(key).second,
"Duplicate keys are not allowed: ('{}').",
key);
const auto value = StringView(
entry.data() + delimiterPos + 1, entry.size() - delimiterPos - 1);

auto [keyWriter, valueWriter] = out.add_item();
keyWriter.setNoCopy(StringView(key));
valueWriter.setNoCopy(value);
}
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_executable(
SortArrayTest.cpp
SplitFunctionsTest.cpp
StringTest.cpp
StringToMapTest.cpp
UnscaledValueFunctionTest.cpp
XxHash64Test.cpp)

Expand Down
41 changes: 0 additions & 41 deletions velox/functions/sparksql/tests/SplitFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,46 +105,5 @@ TEST_F(SplitTest, longStrings) {
',',
{{{"abcdefghijklkmnopqrstuvwxyz"}}});
}

class StringToMapTest : public SparkFunctionBaseTest {
protected:
void testStringToMap(
const std::vector<StringView>& inputs,
const std::vector<std::pair<StringView, std::optional<StringView>>>&
expect) {
std::vector<VectorPtr> row;
row.emplace_back(makeFlatVector<StringView>({inputs[0]}));
std::string expr =
fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]);
auto result = evaluate<MapVector>(expr, makeRowVector(row));
auto expected = makeMapVector<StringView, StringView>({expect});
assertEqualVectors(result, expected);
}
};

TEST_F(StringToMapTest, Basics) {
testStringToMap(
{"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}});
testStringToMap({"", ",", ":"}, {{"", std::nullopt}});
testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}});
testStringToMap(
{"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"", ",", "="}, {{"", std::nullopt}});
testStringToMap(
{"a::1,b::2,c::3", ",", "c"},
{{"", "::3"}, {"a::1", std::nullopt}, {"b::2", std::nullopt}});
testStringToMap(
{"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
// Same delimiters.
testStringToMap(
{"a:1_b:2_c:3", "_", "_"},
{{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
// Exception for duplicated keys.
VELOX_ASSERT_THROW(
testStringToMap({"a:1,b:2,a:3", ",", ":"}, {{"a", "3"}, {"b", "2"}}),
"Duplicated keys ('a') are not allowed.");
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test
63 changes: 63 additions & 0 deletions velox/functions/sparksql/tests/StringToMapTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/type/Type.h"

namespace facebook::velox::functions::sparksql::test {
using namespace facebook::velox::test;
namespace {
class StringToMapTest : public SparkFunctionBaseTest {
protected:
void testStringToMap(
const std::vector<StringView>& inputs,
const std::vector<std::pair<StringView, std::optional<StringView>>>&
expect) {
std::vector<VectorPtr> row;
row.emplace_back(makeFlatVector<StringView>({inputs[0]}));
std::string expr =
fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]);
auto result = evaluate<MapVector>(expr, makeRowVector(row));
auto expected = makeMapVector<StringView, StringView>({expect});
assertEqualVectors(result, expected);
}
};

TEST_F(StringToMapTest, Basics) {
testStringToMap(
{"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}});
testStringToMap({"", ",", ":"}, {{"", std::nullopt}});
testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}});
testStringToMap(
{"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"", ",", "="}, {{"", std::nullopt}});
testStringToMap(
{"a::1,b::2,c::3", ",", "c"},
{{"", "::3"}, {"a::1", std::nullopt}, {"b::2", std::nullopt}});
testStringToMap(
{"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
// Same delimiters.
testStringToMap(
{"a:1_b:2_c:3", "_", "_"},
{{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
// Exception for duplicated keys.
VELOX_ASSERT_THROW(
testStringToMap({"a:1,b:2,a:3", ",", ":"}, {{"a", "3"}, {"b", "2"}}),
"Duplicate keys are not allowed: ('a').");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 03e3e49

Please sign in to comment.