From 03e3e495ecf928784c37c09dba7cf9b111f3491d Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Tue, 5 Dec 2023 22:06:31 +0800 Subject: [PATCH] Use simple function api --- velox/functions/sparksql/Register.cpp | 9 +- velox/functions/sparksql/SplitFunctions.cpp | 117 ------------------ velox/functions/sparksql/StringToMap.h | 103 +++++++++++++++ velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/SplitFunctionsTest.cpp | 41 ------ .../sparksql/tests/StringToMapTest.cpp | 63 ++++++++++ 6 files changed, 175 insertions(+), 159 deletions(-) create mode 100644 velox/functions/sparksql/StringToMap.h create mode 100644 velox/functions/sparksql/tests/StringToMapTest.cpp diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 91cd0af6b290c..56612d90fa68e 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -35,6 +35,7 @@ #include "velox/functions/sparksql/RegisterCompare.h" #include "velox/functions/sparksql/Size.h" #include "velox/functions/sparksql/String.h" +#include "velox/functions/sparksql/StringToMap.h" #include "velox/functions/sparksql/UnscaledValueFunction.h" namespace facebook::velox::functions { @@ -144,6 +145,13 @@ void registerFunctions(const std::string& prefix) { int32_t, int32_t>({prefix + "overlay"}); + registerFunction< + sparksql::StringToMapFunction, + Map, + Varchar, + Varchar, + Varchar>({prefix + "str_to_map"}); + registerFunction( {prefix + "left"}); @@ -153,7 +161,6 @@ void registerFunctions(const std::string& prefix) { prefix + "length", lengthSignatures(), makeLength); registerFunction( {prefix + "substring_index"}); - VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix + "str_to_map"); registerFunction({prefix + "md5"}); registerFunction( diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index d30ffed81ab66..d845a786e483a 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -134,119 +134,6 @@ std::vector> splitSignatures() { .constantArgumentType("varchar") .build()}; } - -class StringToMap final : public exec::VectorFunction { - public: - StringToMap(char entryDelim, char keyValueDelim) - : entryDelim_(entryDelim), keyValueDelim_(keyValueDelim) {} - - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /* outputType */, - exec::EvalCtx& context, - VectorPtr& result) const override { - exec::DecodedArgs decodedArgs(rows, args, context); - DecodedVector* inputString = decodedArgs.at(0); - - BaseVector::ensureWritable( - rows, MAP(VARCHAR(), VARCHAR()), context.pool(), result); - exec::VectorWriter> resultWriter; - resultWriter.init(*result->as()); - - context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - folly::F14FastSet keys; - resultWriter.setOffset(row); - auto& mapWriter = resultWriter.current(); - - const StringView& current = inputString->valueAt(row); - const char* pos = current.begin(); - const char* end = pos + current.size(); - const char* entryEnd; - const char* keyEnd; - do { - entryEnd = std::find(pos, end, entryDelim_); - keyEnd = std::find(pos, entryEnd, keyValueDelim_); - const auto key = StringView(pos, keyEnd - pos); - VELOX_USER_CHECK( - keys.insert(key).second, - "Duplicated keys ('{}') are not allowed.", - key); - if (keyEnd == entryEnd) { - mapWriter.add_null().append(key); - } else { - auto [keyWriter, valueWriter] = mapWriter.add_item(); - keyWriter.setNoCopy(key); - valueWriter.setNoCopy(StringView(keyEnd + 1, entryEnd - keyEnd - 1)); - } - pos = entryEnd + 1; // Skip past delim. - } while (entryEnd != end); - - resultWriter.commit(); - }); - - resultWriter.finish(); - - // Reuses input buffer. - result->as() - ->mapKeys() - ->as>() - ->acquireSharedStringBuffers(inputString->base()); - result->as() - ->mapValues() - ->as>() - ->acquireSharedStringBuffers(inputString->base()); - } - - private: - const char entryDelim_; - const char keyValueDelim_; -}; - -/// Currently only supports single-character entryDelim & keyValueDelim and -/// these two delimiters must be constant. -std::shared_ptr createStringToMap( - const std::string& /*name*/, - const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { - VELOX_USER_CHECK_EQ( - inputArgs.size(), 3, "Expects 3 arguments for StringToMap."); - - auto getDelimiter = - [](exec::VectorFunctionArg inputArg) -> std::optional { - BaseVector* constantVector = inputArg.constantValue.get(); - VELOX_USER_CHECK_NOT_NULL( - constantVector, - "StringToMap requires constant entry/key-value delimiter."); - const auto constantStringView = - constantVector->as>(); - if (constantStringView->isNullAt(0)) { - return std::nullopt; - } - VELOX_USER_CHECK( - constantStringView->valueAt(0).size() == 1, - "StringToMap only supports single-character entry/key-value delimiter."); - return constantStringView->valueAt(0).data()[0]; - }; - std::optional entryDelim = getDelimiter(inputArgs[1]); - std::optional keyValueDelim = getDelimiter(inputArgs[2]); - // As isDefaultNullBehavior = true. - if (!entryDelim.has_value() || !keyValueDelim.has_value()) { - return std::make_shared(); - } - return std::make_shared(*entryDelim, *keyValueDelim); -} - -// varchar, varchar, varchar -> map(varchar, varchar). -std::vector> stringToMapSignatures() { - return {exec::FunctionSignatureBuilder() - .returnType("map(varchar, varchar)") - .argumentType("varchar") - .argumentType("varchar") - .argumentType("varchar") - .build()}; -} - } // namespace VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( @@ -254,8 +141,4 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( splitSignatures(), createSplit); -VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( - udf_str_to_map, - stringToMapSignatures(), - createStringToMap); } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/StringToMap.h b/velox/functions/sparksql/StringToMap.h new file mode 100644 index 0000000000000..44bc27b3eece4 --- /dev/null +++ b/velox/functions/sparksql/StringToMap.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "folly/container/F14Set.h" +#include "velox/functions/Udf.h" + +namespace facebook::velox::functions::sparksql { + +template +struct StringToMapFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExecCtx); + + // Results refer to strings in the first argument. + static constexpr int32_t reuse_strings_from_arg = 0; + + void call( + out_type>& out, + const arg_type& input, + const arg_type& entryDelimiter, + const arg_type& keyValueDelimiter) { + VELOX_USER_CHECK(!entryDelimiter.empty(), "entryDelimiter is empty"); + VELOX_USER_CHECK(!keyValueDelimiter.empty(), "keyValueDelimiter is empty"); + + callImpl( + out, + toStringView(input), + toStringView(entryDelimiter), + toStringView(keyValueDelimiter)); + } + + private: + static std::string_view toStringView(const arg_type& input) { + return std::string_view(input.data(), input.size()); + } + + void callImpl( + out_type>& out, + std::string_view input, + std::string_view entryDelimiter, + std::string_view keyValueDelimiter) const { + size_t pos = 0; + + folly::F14FastSet keys; + + auto nextEntryPos = input.find(entryDelimiter, pos); + while (nextEntryPos != std::string::npos) { + processEntry( + out, + std::string_view(input.data() + pos, nextEntryPos - pos), + keyValueDelimiter, + keys); + + pos = nextEntryPos + 1; + nextEntryPos = input.find(entryDelimiter, pos); + } + + processEntry( + out, + std::string_view(input.data() + pos, input.size() - pos), + keyValueDelimiter, + keys); + } + + void processEntry( + out_type>& out, + std::string_view entry, + std::string_view keyValueDelimiter, + folly::F14FastSet& keys) const { + const auto delimiterPos = entry.find(keyValueDelimiter, 0); + // Not found key/value delimiter. + if (delimiterPos == std::string::npos) { + out.add_null().append(StringView(entry)); + return; + } + const auto key = std::string_view(entry.data(), delimiterPos); + VELOX_USER_CHECK( + keys.insert(key).second, + "Duplicate keys are not allowed: ('{}').", + key); + const auto value = StringView( + entry.data() + delimiterPos + 1, entry.size() - delimiterPos - 1); + + auto [keyWriter, valueWriter] = out.add_item(); + keyWriter.setNoCopy(StringView(key)); + valueWriter.setNoCopy(value); + } +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 31a68ea7bf264..f96f393b78e0d 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable( SortArrayTest.cpp SplitFunctionsTest.cpp StringTest.cpp + StringToMapTest.cpp UnscaledValueFunctionTest.cpp XxHash64Test.cpp) diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index 1a7a05a3e37cf..4e5023d38584c 100644 --- a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -105,46 +105,5 @@ TEST_F(SplitTest, longStrings) { ',', {{{"abcdefghijklkmnopqrstuvwxyz"}}}); } - -class StringToMapTest : public SparkFunctionBaseTest { - protected: - void testStringToMap( - const std::vector& inputs, - const std::vector>>& - expect) { - std::vector row; - row.emplace_back(makeFlatVector({inputs[0]})); - std::string expr = - fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]); - auto result = evaluate(expr, makeRowVector(row)); - auto expected = makeMapVector({expect}); - assertEqualVectors(result, expected); - } -}; - -TEST_F(StringToMapTest, Basics) { - testStringToMap( - {"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); - testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}}); - testStringToMap({"", ",", ":"}, {{"", std::nullopt}}); - testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}}); - testStringToMap( - {"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); - testStringToMap({"", ",", "="}, {{"", std::nullopt}}); - testStringToMap( - {"a::1,b::2,c::3", ",", "c"}, - {{"", "::3"}, {"a::1", std::nullopt}, {"b::2", std::nullopt}}); - testStringToMap( - {"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); - // Same delimiters. - testStringToMap( - {"a:1_b:2_c:3", "_", "_"}, - {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}}); - // Exception for duplicated keys. - VELOX_ASSERT_THROW( - testStringToMap({"a:1,b:2,a:3", ",", ":"}, {{"a", "3"}, {"b", "2"}}), - "Duplicated keys ('a') are not allowed."); -} - } // namespace } // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/StringToMapTest.cpp b/velox/functions/sparksql/tests/StringToMapTest.cpp new file mode 100644 index 0000000000000..4ce9f21443968 --- /dev/null +++ b/velox/functions/sparksql/tests/StringToMapTest.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions::sparksql::test { +using namespace facebook::velox::test; +namespace { +class StringToMapTest : public SparkFunctionBaseTest { + protected: + void testStringToMap( + const std::vector& inputs, + const std::vector>>& + expect) { + std::vector row; + row.emplace_back(makeFlatVector({inputs[0]})); + std::string expr = + fmt::format("str_to_map(c0, '{}', '{}')", inputs[1], inputs[2]); + auto result = evaluate(expr, makeRowVector(row)); + auto expected = makeMapVector({expect}); + assertEqualVectors(result, expected); + } +}; + +TEST_F(StringToMapTest, Basics) { + testStringToMap( + {"a:1,b:2,c:3", ",", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); + testStringToMap({"a: ,b:2", ",", ":"}, {{"a", " "}, {"b", "2"}}); + testStringToMap({"", ",", ":"}, {{"", std::nullopt}}); + testStringToMap({"a", ",", ":"}, {{"a", std::nullopt}}); + testStringToMap( + {"a=1,b=2,c=3", ",", "="}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); + testStringToMap({"", ",", "="}, {{"", std::nullopt}}); + testStringToMap( + {"a::1,b::2,c::3", ",", "c"}, + {{"", "::3"}, {"a::1", std::nullopt}, {"b::2", std::nullopt}}); + testStringToMap( + {"a:1_b:2_c:3", "_", ":"}, {{"a", "1"}, {"b", "2"}, {"c", "3"}}); + // Same delimiters. + testStringToMap( + {"a:1_b:2_c:3", "_", "_"}, + {{"a:1", std::nullopt}, {"b:2", std::nullopt}, {"c:3", std::nullopt}}); + // Exception for duplicated keys. + VELOX_ASSERT_THROW( + testStringToMap({"a:1,b:2,a:3", ",", ":"}, {{"a", "3"}, {"b", "2"}}), + "Duplicate keys are not allowed: ('a')."); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test \ No newline at end of file