Skip to content

Commit

Permalink
Add str_to_map Spark function (#5842)
Browse files Browse the repository at this point in the history
Summary:
Spark implementation.: https://github.com/apache/spark/blob/v3.2.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala#L531

It's similar to `split_to_map` in prestosql which is not implemented.

Pull Request resolved: #5842

Reviewed By: Yuhta

Differential Revision: D51882086

Pulled By: mbasmanova

fbshipit-source-id: ef4444e2c9f74e537091bf4b48d0114ffd0ba405
  • Loading branch information
majian4work authored and facebook-github-bot committed Dec 6, 2023
1 parent 2e71d8e commit af7d1ef
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 0 deletions.
14 changes: 14 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,20 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
SELECT startswith('js SQL', 'SQL'); -- false
SELECT startswith('js SQL', null); -- NULL

.. spark:function:: str_to_map(string, entryDelimiter, keyValueDelimiter) -> map(string, string)
Returns a map by splitting ``string`` into entries with ``entryDelimiter`` and splitting
each entry into key/value with ``keyValueDelimiter``.
``entryDelimiter`` and ``keyValueDelimiter`` must be constant strings with single ascii
character. Allows ``keyValueDelimiter`` not found when splitting an entry. Throws exception
when duplicate map keys are found for single row's result, consistent with Spark's default
behavior. ::

SELECT str_to_map('a:1,b:2,c:3', ',', ':'); -- {"a":"1","b":"2","c":"3"}
SELECT str_to_map('a', ',', ':'); -- {"a":NULL}
SELECT str_to_map('', ',', ':'); -- {"":NULL}
SELECT str_to_map('a:1,b:2,c:3', ',', ','); -- {"a:1":NULL,"b:2":NULL,"c:3":NULL}

.. spark:function:: substring(string, start) -> varchar
Returns the rest of ``string`` from the starting position ``start``.
Expand Down
8 changes: 8 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "velox/functions/sparksql/RegisterCompare.h"
#include "velox/functions/sparksql/Size.h"
#include "velox/functions/sparksql/String.h"
#include "velox/functions/sparksql/StringToMap.h"
#include "velox/functions/sparksql/UnscaledValueFunction.h"
#include "velox/functions/sparksql/specialforms/DecimalRound.h"
#include "velox/functions/sparksql/specialforms/MakeDecimal.h"
Expand Down Expand Up @@ -153,6 +154,13 @@ void registerFunctions(const std::string& prefix) {
int32_t,
int32_t>({prefix + "overlay"});

registerFunction<
sparksql::StringToMapFunction,
Map<Varchar, Varchar>,
Varchar,
Varchar,
Varchar>({prefix + "str_to_map"});

registerFunction<sparksql::LeftFunction, Varchar, Varchar, int32_t>(
{prefix + "left"});

Expand Down
102 changes: 102 additions & 0 deletions velox/functions/sparksql/StringToMap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "folly/container/F14Set.h"
#include "velox/functions/Udf.h"

namespace facebook::velox::functions::sparksql {

template <typename T>
struct StringToMapFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Results refer to strings in the first argument.
static constexpr int32_t reuse_strings_from_arg = 0;

void call(
out_type<Map<Varchar, Varchar>>& out,
const arg_type<Varchar>& input,
const arg_type<Varchar>& entryDelimiter,
const arg_type<Varchar>& keyValueDelimiter) {
VELOX_USER_CHECK_EQ(
entryDelimiter.size(), 1, "entryDelimiter's size should be 1.");
VELOX_USER_CHECK_EQ(
keyValueDelimiter.size(), 1, "keyValueDelimiter's size should be 1.");

callImpl(
out,
toStringView(input),
toStringView(entryDelimiter),
toStringView(keyValueDelimiter));
}

private:
static std::string_view toStringView(const arg_type<Varchar>& input) {
return std::string_view(input.data(), input.size());
}

void callImpl(
out_type<Map<Varchar, Varchar>>& out,
std::string_view input,
std::string_view entryDelimiter,
std::string_view keyValueDelimiter) const {
size_t pos = 0;
folly::F14FastSet<std::string_view> keys;

auto nextEntryPos = input.find(entryDelimiter, pos);
while (nextEntryPos != std::string::npos) {
processEntry(
out,
std::string_view(input.data() + pos, nextEntryPos - pos),
keyValueDelimiter,
keys);

pos = nextEntryPos + 1;
nextEntryPos = input.find(entryDelimiter, pos);
}

processEntry(
out,
std::string_view(input.data() + pos, input.size() - pos),
keyValueDelimiter,
keys);
}

void processEntry(
out_type<Map<Varchar, Varchar>>& out,
std::string_view entry,
std::string_view keyValueDelimiter,
folly::F14FastSet<std::string_view>& keys) const {
const auto delimiterPos = entry.find(keyValueDelimiter, 0);
// Allows keyValue delimiter not found.
if (delimiterPos == std::string::npos) {
out.add_null().setNoCopy(StringView(entry));
return;
}
const auto key = std::string_view(entry.data(), delimiterPos);
VELOX_USER_CHECK(
keys.insert(key).second, "Duplicate keys are not allowed: '{}'.", key);
const auto value = StringView(
entry.data() + delimiterPos + 1, entry.size() - delimiterPos - 1);

auto [keyWriter, valueWriter] = out.add_item();
keyWriter.setNoCopy(StringView(key));
valueWriter.setNoCopy(value);
}
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_executable(
SortArrayTest.cpp
SplitFunctionsTest.cpp
StringTest.cpp
StringToMapTest.cpp
UnscaledValueFunctionTest.cpp
XxHash64Test.cpp)

Expand Down
98 changes: 98 additions & 0 deletions velox/functions/sparksql/tests/StringToMapTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {
class StringToMapTest : public SparkFunctionBaseTest {
protected:
VectorPtr evaluateStringToMap(const std::vector<StringView>& inputs) {
const std::string expr =
fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]);
return evaluate<MapVector>(
expr, makeRowVector({makeFlatVector<StringView>({inputs[0]})}));
}

void testStringToMap(
const std::vector<StringView>& inputs,
const std::vector<std::pair<StringView, std::optional<StringView>>>&
expect) {
auto result = evaluateStringToMap(inputs);
auto expectVector = makeMapVector<StringView, StringView>({expect});
assertEqualVectors(result, expectVector);
}
};

TEST_F(StringToMapTest, basic) {
testStringToMap(
{"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}});
testStringToMap({"a:,b:2", ",", ":"}, {{"a", ""}, {"b", "2"}});
testStringToMap({"", ",", ":"}, {{"", std::nullopt}});
testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}});
testStringToMap(
{"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
testStringToMap({"", ",", "="}, {{"", std::nullopt}});
testStringToMap(
{"a::1,b::2,c::3", ",", "c"},
{{"a::1", std::nullopt}, {"b::2", std::nullopt}, {"", "::3"}});
testStringToMap(
{"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});

// Same delimiters.
testStringToMap(
{"a:1,b:2,c:3", ",", ","},
{{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
testStringToMap(
{"a:1_b:2_c:3", "_", "_"},
{{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});

// Exception for illegal delimiters.
// Empty string is used.
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", "", ":"}),
"entryDelimiter's size should be 1.");
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", ",", ""}),
"keyValueDelimiter's size should be 1.");
// Delimiter's length > 1.
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", ";;", ":"}),
"entryDelimiter's size should be 1.");
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", ",", "::"}),
"keyValueDelimiter's size should be 1.");
// Unicode character is used.
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", "å", ":"}),
"entryDelimiter's size should be 1.");
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2", ",", "æ"}),
"keyValueDelimiter's size should be 1.");

// Exception for duplicated keys.
VELOX_ASSERT_THROW(
evaluateStringToMap({"a:1,b:2,a:3", ",", ":"}),
"Duplicate keys are not allowed: 'a'.");
VELOX_ASSERT_THROW(
evaluateStringToMap({":1,:2", ",", ":"}),
"Duplicate keys are not allowed: ''.");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit af7d1ef

Please sign in to comment.