From 3d08dd702a620d59d5307d94e2da74b72b8ca199 Mon Sep 17 00:00:00 2001 From: Ma-Jian1 Date: Wed, 26 Jul 2023 10:04:49 +0800 Subject: [PATCH] support spark str_to_map --- velox/docs/functions/spark/string.rst | 13 +++ velox/functions/sparksql/Register.cpp | 8 ++ velox/functions/sparksql/StringToMap.h | 103 ++++++++++++++++++ velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/StringToMapTest.cpp | 66 +++++++++++ 5 files changed, 191 insertions(+) create mode 100644 velox/functions/sparksql/StringToMap.h create mode 100644 velox/functions/sparksql/tests/StringToMapTest.cpp diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 2959943cb5c5..8c3d9d2f9d99 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -178,6 +178,19 @@ Unless specified otherwise, all functions return NULL if at least one of the arg SELECT startswith('js SQL', 'SQL'); -- false SELECT startswith('js SQL', null); -- NULL +.. spark:function:: str_to_map(string, entryDelim, keyValueDelim) -> map(string, string) + + Returns a map by splitting ``string`` into entries with ``entryDelim`` and splitting + each entry into key/value with ``keyValueDelim``. + Only supports constant single-character ``entryDelim`` and ``keyValueDelim``. Throws + an exception when duplicate map keys are found in a single row's result, consistent with + Spark's default behavior. :: + + SELECT str_to_map('a:1,b:2,c:3', ',', ':'); -- {"a":"1","b":"2","c":"3"} + SELECT str_to_map('a', ',', ':'); -- {"a":NULL} + SELECT str_to_map('', ',', ':'); -- {"":NULL} + SELECT str_to_map('a:1,b:2,c:3', ',', ','); -- {"a:1":NULL,"b:2":NULL,"c:3":NULL} + .. spark:function:: substring(string, start) -> varchar Returns the rest of ``string`` from the starting position ``start``. 
diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index fb896fd29e75..98694da86cc1 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -36,6 +36,7 @@ #include "velox/functions/sparksql/RegisterCompare.h" #include "velox/functions/sparksql/Size.h" #include "velox/functions/sparksql/String.h" +#include "velox/functions/sparksql/StringToMap.h" #include "velox/functions/sparksql/UnscaledValueFunction.h" #include "velox/functions/sparksql/specialforms/DecimalRound.h" #include "velox/functions/sparksql/specialforms/MakeDecimal.h" @@ -153,6 +154,13 @@ void registerFunctions(const std::string& prefix) { int32_t, int32_t>({prefix + "overlay"}); + registerFunction< + sparksql::StringToMapFunction, + Map, + Varchar, + Varchar, + Varchar>({prefix + "str_to_map"}); + registerFunction( {prefix + "left"}); diff --git a/velox/functions/sparksql/StringToMap.h b/velox/functions/sparksql/StringToMap.h new file mode 100644 index 000000000000..cc159b5752dd --- /dev/null +++ b/velox/functions/sparksql/StringToMap.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "folly/container/F14Set.h" +#include "velox/functions/Udf.h" + +namespace facebook::velox::functions::sparksql { + +template +struct StringToMapFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExecCtx); + + // Results refer to strings in the first argument. + static constexpr int32_t reuse_strings_from_arg = 0; + + void call( + out_type>& out, + const arg_type& input, + const arg_type& entryDelimiter, + const arg_type& keyValueDelimiter) { + VELOX_USER_CHECK(!entryDelimiter.empty(), "entryDelimiter is empty"); + VELOX_USER_CHECK(!keyValueDelimiter.empty(), "keyValueDelimiter is empty"); + + callImpl( + out, + toStringView(input), + toStringView(entryDelimiter), + toStringView(keyValueDelimiter)); + } + + private: + static std::string_view toStringView(const arg_type& input) { + return std::string_view(input.data(), input.size()); + } + + void callImpl( + out_type>& out, + std::string_view input, + std::string_view entryDelimiter, + std::string_view keyValueDelimiter) const { + size_t pos = 0; + + folly::F14FastSet keys; + + auto nextEntryPos = input.find(entryDelimiter, pos); + while (nextEntryPos != std::string::npos) { + processEntry( + out, + std::string_view(input.data() + pos, nextEntryPos - pos), + keyValueDelimiter, + keys); + + pos = nextEntryPos + 1; + nextEntryPos = input.find(entryDelimiter, pos); + } + + processEntry( + out, + std::string_view(input.data() + pos, input.size() - pos), + keyValueDelimiter, + keys); + } + + void processEntry( + out_type>& out, + std::string_view entry, + std::string_view keyValueDelimiter, + folly::F14FastSet& keys) const { + const auto delimiterPos = entry.find(keyValueDelimiter, 0); + // Not found key/value delimiter. 
+ if (delimiterPos == std::string::npos) { + out.add_null().setNoCopy(StringView(entry)); + return; + } + const auto key = std::string_view(entry.data(), delimiterPos); + VELOX_USER_CHECK( + keys.insert(key).second, + "Duplicate keys are not allowed: ('{}').", + key); + const auto value = StringView( + entry.data() + delimiterPos + 1, entry.size() - delimiterPos - 1); + + auto [keyWriter, valueWriter] = out.add_item(); + keyWriter.setNoCopy(StringView(key)); + valueWriter.setNoCopy(value); + } +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 7e7650e12173..49b63c10d815 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -38,6 +38,7 @@ add_executable( SortArrayTest.cpp SplitFunctionsTest.cpp StringTest.cpp + StringToMapTest.cpp UnscaledValueFunctionTest.cpp XxHash64Test.cpp) diff --git a/velox/functions/sparksql/tests/StringToMapTest.cpp b/velox/functions/sparksql/tests/StringToMapTest.cpp new file mode 100644 index 000000000000..571aa6d2faa2 --- /dev/null +++ b/velox/functions/sparksql/tests/StringToMapTest.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
+#include "velox/type/Type.h"
+
+namespace facebook::velox::functions::sparksql::test {
+using namespace facebook::velox::test;
+namespace {
+// Fixture for the Spark str_to_map function.
+// NOTE(review): the element types below (StringView, VectorPtr) were
+// reconstructed because the template argument lists were missing from the
+// reviewed text - confirm against the original sources.
+class StringToMapTest : public SparkFunctionBaseTest {
+ protected:
+  // Evaluates str_to_map over a single-row input and compares the result with
+  // 'expect'.
+  // 'inputs' holds {input string, entry delimiter, key/value delimiter}; the
+  // two delimiters are spliced into the expression as string literals.
+  // 'expect' lists the expected key/value pairs; std::nullopt is a NULL value.
+  void testStringToMap(
+      const std::vector<StringView>& inputs,
+      const std::vector<std::pair<StringView, std::optional<StringView>>>&
+          expect) {
+    std::vector<VectorPtr> row;
+    row.emplace_back(makeFlatVector<StringView>({inputs[0]}));
+    std::string expr =
+        fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]);
+    auto result = evaluate(expr, makeRowVector(row));
+    auto expected = makeMapVector<StringView, StringView>({expect});
+    assertEqualVectors(result, expected);
+  }
+};
+
+TEST_F(StringToMapTest, Basics) {
+  testStringToMap(
+      {"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
+  testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}});
+  testStringToMap({"", ",", ":"}, {{"", std::nullopt}});
+  testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}});
+  testStringToMap(
+      {"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
+  testStringToMap({"", ",", "="}, {{"", std::nullopt}});
+  testStringToMap(
+      {"a::1,b::2,c::3", ",", "c"},
+      {{"", "::3"}, {"a::1", std::nullopt}, {"b::2", std::nullopt}});
+  testStringToMap(
+      {"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}});
+  // Only the first key/value delimiter splits an entry; later ones stay in
+  // the value.
+  testStringToMap({"a:1:2,b:3", ",", ":"}, {{"a", "1:2"}, {"b", "3"}});
+  // A trailing entry delimiter yields a final empty entry with a NULL value.
+  testStringToMap({"a:1,", ",", ":"}, {{"a", "1"}, {"", std::nullopt}});
+  // Same delimiters.
+  testStringToMap(
+      {"a:1,b:2,c:3", ",", ","},
+      {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
+  testStringToMap(
+      {"a:1_b:2_c:3", "_", "_"},
+      {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}});
+  // Exception for duplicated keys.
+  VELOX_ASSERT_THROW(
+      testStringToMap({"a:1,b:2,a:3", ",", ":"}, {{"a", "3"}, {"b", "2"}}),
+      "Duplicate keys are not allowed: ('a').");
+}
+} // namespace
+} // namespace facebook::velox::functions::sparksql::test