diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index c1f433a53eced..e83d9bcad52e9 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -213,4 +213,12 @@ Unless specified otherwise, all functions return NULL if at least one of the arg Returns string with all characters changed to uppercase. :: - SELECT upper('SparkSql'); -- SPARKSQL \ No newline at end of file + SELECT upper('SparkSql'); -- SPARKSQL + +.. spark:function:: str_to_map(text[, pairDelim[, keyValueDelim]]) -> map(string, string) + + Create a map after splitting the ``text`` into key/value pairs using delimiters. + Default delimiters are ',' for ``pairDelim`` and ':' for ``keyValueDelim``. + Each delimiter must be a single character; regular expressions are not supported. :: + + SELECT str_to_map('a:1,b:2,c:3', ',', ':') -- {"a":"1","b":"2","c":"3"} \ No newline at end of file diff --git a/velox/expression/ComplexWriterTypes.h b/velox/expression/ComplexWriterTypes.h index 1ab6a242e10a5..fd99e7ee5e42b 100644 --- a/velox/expression/ComplexWriterTypes.h +++ b/velox/expression/ComplexWriterTypes.h @@ -428,7 +428,7 @@ class MapWriter { std::tuple, PrimitiveWriter> operator[]( vector_size_t index) { - static_assert(std_interface, "operator [] not allowed for this map"); + // static_assert(std_interface, "operator [] not allowed for this map"); VELOX_DCHECK_LT(index, length_, "out of bound access"); return { PrimitiveWriter{keysVector_, innerOffset_ + index}, diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index b8cc01b75e946..67004260d67f9 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -119,6 +119,7 @@ void registerFunctions(const std::string& prefix) { prefix + "instr", instrSignatures(), makeInstr); exec::registerStatefulVectorFunction( prefix + "length", lengthSignatures(), makeLength); + VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix 
+ "str_to_map"); registerFunction({prefix + "md5"}); registerFunction( diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index 4d092e6928373..d113fcff5ca7e 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -14,8 +14,18 @@ * limitations under the License. */ +#include +#include +#include +#include #include +#include "expression/DecodedArgs.h" +#include "type/StringView.h" +#include "type/Type.h" +#include "vector/ComplexVector.h" +#include "vector/ConstantVector.h" +#include "vector/TypeAliases.h" #include "velox/expression/VectorFunction.h" #include "velox/expression/VectorWriters.h" @@ -134,10 +144,119 @@ std::vector> signatures() { .build()}; } +// str_to_map(expr [, pairDelim [, keyValueDelim] ] ) +class StrToMap final : public exec::VectorFunction { + public: + StrToMap() = default; + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& /* outputType */, + exec::EvalCtx& context, + VectorPtr& result) const override { + exec::DecodedArgs decodedArgs(rows, args, context); + DecodedVector* strings = decodedArgs.at(0); + char pairDelim = ','; + char kvDelim = ':'; + VELOX_CHECK( + !args.empty(), + "StrToMap function should provide at least one argument"); + if (args.size() > 1) { + pairDelim = args[1]->as>()->valueAt(0).data()[0]; + if (args.size() > 2) { + kvDelim = args[2]->as>()->valueAt(0).data()[0]; + } + } + + BaseVector::ensureWritable( + rows, MAP(VARCHAR(), VARCHAR()), context.pool(), result); + exec::VectorWriter> resultWriter; + resultWriter.init(*result->as()); + + std::unordered_map keyToIdx; + rows.applyToSelected([&](vector_size_t row) { + resultWriter.setOffset(row); + auto& mapWriter = resultWriter.current(); + + const StringView& current = strings->valueAt(row); + const char* pos = current.begin(); + const char* end = pos + current.size(); + const char* pair; + const char* kv; + do { + pair = 
std::find(pos, end, pairDelim); + kv = std::find(pos, pair, kvDelim); + auto key = StringView(pos, kv - pos); + auto iter = keyToIdx.find(key); + if (iter == keyToIdx.end()) { + keyToIdx.emplace(key, mapWriter.size()); + if (kv == pair) { + mapWriter.add_null().append(key); + } else { + auto [keyWriter, valueWriter] = mapWriter.add_item(); + keyWriter.append(key); + valueWriter.append(StringView(kv + 1, pair - kv - 1)); + } + } else { + auto valueWriter = std::get<1>(mapWriter[iter->second]); + if (kv == pair) { + valueWriter = std::nullopt; + } else { + valueWriter = StringView(kv + 1, pair - kv - 1); + } + } + + pos = pair + 1; // Skip past delim. + } while (pair != end); + + resultWriter.commit(); + }); + + resultWriter.finish(); + + // Ensure that our result elements vector uses the same string buffer as + // the input vector of strings. + result->as() + ->mapKeys() + ->as>() + ->acquireSharedStringBuffers(strings->base()); + result->as() + ->mapValues() + ->as>() + ->acquireSharedStringBuffers(strings->base()); + } +}; + +std::vector> strToMapSignatures() { + // varchar, varchar -> array(varchar) + return { + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .build(), + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .argumentType("varchar") + .build(), + exec::FunctionSignatureBuilder() + .returnType("map(varchar, varchar)") + .argumentType("varchar") + .argumentType("varchar") + .argumentType("varchar") + .build()}; +} + } // namespace VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_regexp_split, signatures(), createSplit); + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_str_to_map, + strToMapSignatures(), + std::make_unique()); } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index 8928849a44ce2..6798bf11957e2 100644 --- 
a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" @@ -105,5 +106,76 @@ TEST_F(SplitTest, longStrings) { {{{"abcdefghijklkmnopqrstuvwxyz"}}}); } +class StrToMapTest : public SparkFunctionBaseTest { + protected: + void testStrToMap( + const std::vector& parameters, + const std::vector>>& + expect) { + std::string expr; + switch (parameters.size()) { + case 1: + expr = "str_to_map(c0)"; + break; + case 2: + expr = "str_to_map(c0, c1)"; + break; + case 3: + expr = "str_to_map(c0, c1, c2)"; + break; + default: + VELOX_FAIL("Unsupported arguments size"); + } + std::vector row; + row.reserve(parameters.size()); + for (auto parameter : parameters) { + row.emplace_back(makeFlatVector({parameter})); + } + auto result = evaluate(expr, makeRowVector(row)); + auto expected = makeMapVector({expect}); + ::facebook::velox::test::assertEqualVectors(result, expected); + } +}; + +TEST_F(StrToMapTest, Basics) { + // clang-format off + auto table = std::vector, + std::vector>>>>{ + { + {"a:1,b:2,c:3"}, + {{"a", "1"}, {"b", "2"}, {"c", "3"}} + }, + { + {"a: ,b:2"}, + {{"a", " "}, {"b", "2"}} + }, + { + {"a=1,b=2,c=3", ",", "="}, + {{"a", "1"}, {"b", "2"}, {"c", "3"}} + }, + { + {"", ",", "="}, + {{"", std::nullopt}} + }, + { + {"a:1_b:2_c:3", "_"}, + {{"a", "1"}, {"b", "2"}, {"c", "3"}} + }, + { + {"a"}, + {{"a", std::nullopt}} + }, + { + {"a:1,b:2,a:3"}, + {{"a", "3"}, {"b", "2"}} + }, + }; + // clang-format off + for (const auto& [input, expect] : table) { + testStrToMap(input, expect); + } +} + } // namespace } // namespace facebook::velox::functions::sparksql::test