Skip to content

Commit

Permalink
support spark str_to_map
Browse files Browse the repository at this point in the history
  • Loading branch information
majian4work committed Aug 1, 2023
1 parent 517e3e3 commit 8462b33
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 2 deletions.
10 changes: 9 additions & 1 deletion velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,12 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL
.. spark:function:: str_to_map(text[, pairDelim[, keyValueDelim]]) -> map(string, string)
Create a map after splitting the ``text`` into key/value pairs using delimiters.
Default delimiters are ',' for `pairDelim` and ':' for `keyValueDelim`.
Both `pairDelim` and `keyValueDelim` are treated as regular expressions.::

SELECT str_to_map('a:1,b:2,c:3', ',', ':') -- {"a":"1","b":"2","c":"3"}
2 changes: 1 addition & 1 deletion velox/expression/ComplexWriterTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ class MapWriter {

std::tuple<PrimitiveWriter<K, false>, PrimitiveWriter<V>> operator[](
vector_size_t index) {
static_assert(std_interface, "operator [] not allowed for this map");
// static_assert(std_interface, "operator [] not allowed for this map");
VELOX_DCHECK_LT(index, length_, "out of bound access");
return {
PrimitiveWriter<K, false>{keysVector_, innerOffset_ + index},
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ void registerFunctions(const std::string& prefix) {
prefix + "instr", instrSignatures(), makeInstr);
exec::registerStatefulVectorFunction(
prefix + "length", lengthSignatures(), makeLength);
VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix + "str_to_map");

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
119 changes: 119 additions & 0 deletions velox/functions/sparksql/SplitFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,18 @@
* limitations under the License.
*/

#include <iostream>
#include <memory>
#include <optional>
#include <unordered_map>
#include <utility>

#include "expression/DecodedArgs.h"
#include "type/StringView.h"
#include "type/Type.h"
#include "vector/ComplexVector.h"
#include "vector/ConstantVector.h"
#include "vector/TypeAliases.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorWriters.h"

Expand Down Expand Up @@ -134,10 +144,119 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
.build()};
}

// str_to_map(expr [, pairDelim [, keyValueDelim] ] )
class StrToMap final : public exec::VectorFunction {
public:
StrToMap() = default;

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
exec::DecodedArgs decodedArgs(rows, args, context);
DecodedVector* strings = decodedArgs.at(0);
char pairDelim = ',';
char kvDelim = ':';
VELOX_CHECK(
!args.empty(),
"StrToMap function should provide at least one argument");
if (args.size() > 1) {
pairDelim = args[1]->as<SimpleVector<StringView>>()->valueAt(0).data()[0];
if (args.size() > 2) {
kvDelim = args[2]->as<SimpleVector<StringView>>()->valueAt(0).data()[0];
}
}

BaseVector::ensureWritable(
rows, MAP(VARCHAR(), VARCHAR()), context.pool(), result);
exec::VectorWriter<Map<Varchar, Varchar>> resultWriter;
resultWriter.init(*result->as<MapVector>());

std::unordered_map<StringView, vector_size_t> keyToIdx;
rows.applyToSelected([&](vector_size_t row) {
resultWriter.setOffset(row);
auto& mapWriter = resultWriter.current();

const StringView& current = strings->valueAt<StringView>(row);
const char* pos = current.begin();
const char* end = pos + current.size();
const char* pair;
const char* kv;
do {
pair = std::find(pos, end, pairDelim);
kv = std::find(pos, pair, kvDelim);
auto key = StringView(pos, kv - pos);
auto iter = keyToIdx.find(key);
if (iter == keyToIdx.end()) {
keyToIdx.emplace(key, mapWriter.size());
if (kv == pair) {
mapWriter.add_null().append(key);
} else {
auto [keyWriter, valueWriter] = mapWriter.add_item();
keyWriter.append(key);
valueWriter.append(StringView(kv + 1, pair - kv - 1));
}
} else {
auto valueWriter = std::get<1>(mapWriter[iter->second]);
if (kv == pair) {
valueWriter = std::nullopt;
} else {
valueWriter = StringView(kv + 1, pair - kv - 1);
}
}

pos = pair + 1; // Skip past delim.
} while (pair != end);

resultWriter.commit();
});

resultWriter.finish();

// Ensure that our result elements vector uses the same string buffer as
// the input vector of strings.
result->as<MapVector>()
->mapKeys()
->as<FlatVector<StringView>>()
->acquireSharedStringBuffers(strings->base());
result->as<MapVector>()
->mapValues()
->as<FlatVector<StringView>>()
->acquireSharedStringBuffers(strings->base());
}
};

std::vector<std::shared_ptr<exec::FunctionSignature>> strToMapSignatures() {
// varchar, varchar -> array(varchar)
return {
exec::FunctionSignatureBuilder()
.returnType("map(varchar, varchar)")
.argumentType("varchar")
.build(),
exec::FunctionSignatureBuilder()
.returnType("map(varchar, varchar)")
.argumentType("varchar")
.argumentType("varchar")
.build(),
exec::FunctionSignatureBuilder()
.returnType("map(varchar, varchar)")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build()};
}

} // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_regexp_split,
signatures(),
createSplit);

VELOX_DECLARE_VECTOR_FUNCTION(
udf_str_to_map,
strToMapSignatures(),
std::make_unique<StrToMap>());
} // namespace facebook::velox::functions::sparksql
72 changes: 72 additions & 0 deletions velox/functions/sparksql/tests/SplitFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <type/Type.h>
#include <optional>
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

Expand Down Expand Up @@ -105,5 +106,76 @@ TEST_F(SplitTest, longStrings) {
{{{"abcdefghijklkmnopqrstuvwxyz"}}});
}

class StrToMapTest : public SparkFunctionBaseTest {
protected:
void testStrToMap(
const std::vector<StringView>& parameters,
const std::vector<std::pair<StringView, std::optional<StringView>>>&
expect) {
std::string expr;
switch (parameters.size()) {
case 1:
expr = "str_to_map(c0)";
break;
case 2:
expr = "str_to_map(c0, c1)";
break;
case 3:
expr = "str_to_map(c0, c1, c2)";
break;
default:
VELOX_FAIL("Unsupported arguments size");
}
std::vector<VectorPtr> row;
row.reserve(parameters.size());
for (auto parameter : parameters) {
row.emplace_back(makeFlatVector<StringView>({parameter}));
}
auto result = evaluate<MapVector>(expr, makeRowVector(row));
auto expected = makeMapVector<StringView, StringView>({expect});
::facebook::velox::test::assertEqualVectors(result, expected);
}
};

TEST_F(StrToMapTest, Basics) {
// clang-format off
auto table = std::vector<std::tuple<
std::vector<StringView>,
std::vector<std::pair<StringView, std::optional<StringView>>>>>{
{
{"a:1,b:2,c:3"},
{{"a", "1"}, {"b", "2"}, {"c", "3"}}
},
{
{"a: ,b:2"},
{{"a", " "}, {"b", "2"}}
},
{
{"a=1,b=2,c=3", ",", "="},
{{"a", "1"}, {"b", "2"}, {"c", "3"}}
},
{
{"", ",", "="},
{{"", std::nullopt}}
},
{
{"a:1_b:2_c:3", "_"},
{{"a", "1"}, {"b", "2"}, {"c", "3"}}
},
{
{"a"},
{{"a", std::nullopt}}
},
{
{"a:1,b:2,a:3"},
{{"a", "3"}, {"b", "2"}}
},
};
// clang-format off
for (const auto& [input, expect] : table) {
testStrToMap(input, expect);
}
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 8462b33

Please sign in to comment.