Skip to content

Commit

Permalink
Add support for array join on json (#11446)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #11446

Reviewed By: kgpai

Differential Revision: D65507910

Pulled By: HeidiHan0000
  • Loading branch information
HeidiHan0000 authored and facebook-github-bot committed Nov 15, 2024
1 parent 31ae379 commit e353b04
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 2 deletions.
16 changes: 16 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

Expand Down Expand Up @@ -189,6 +191,20 @@ struct ArrayJoinFunction {
result += util::Converter<TypeKind::VARCHAR>::tryCast(value).value();
}

// Appends a string-like array element to the joined output.
//
// When the array element type is JSON and the value both starts and ends
// with a double quote (and is at least two characters long), the enclosing
// pair of quotes is dropped before the value is appended. Every other value
// is appended in full. Nothing between the quotes is rewritten here; the
// remaining bytes are handed to the VARCHAR converter unchanged.
void writeValue(out_type<velox::Varchar>& result, const StringView& value) {
  const bool isQuotedJson = isJsonType(arrayElementType_) &&
      value.size() >= 2 && *value.begin() == '"' && *(value.end() - 1) == '"';

  // The optional returned by tryCast is unwrapped unconditionally: per the
  // original author's note, the to-VARCHAR converter never throws.
  if (isQuotedJson) {
    const std::string_view unquoted(value.data() + 1, value.size() - 2);
    result += util::Converter<TypeKind::VARCHAR>::tryCast(unquoted).value();
  } else {
    result += util::Converter<TypeKind::VARCHAR>::tryCast(value).value();
  }
}

void writeValue(out_type<velox::Varchar>& result, const int32_t& value) {
if (arrayElementType_->isDate()) {
result += util::Converter<TypeKind::VARCHAR>::tryCast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ void registerInternalArrayFunctions() {
}

void registerArrayFunctions(const std::string& prefix) {
registerJsonType();
registerArrayConstructor(prefix + "array_constructor");
VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match");
Expand Down Expand Up @@ -191,6 +192,7 @@ void registerArrayFunctions(const std::string& prefix) {
registerArrayJoinFunctions<Varchar>(prefix);
registerArrayJoinFunctions<Timestamp>(prefix);
registerArrayJoinFunctions<Date>(prefix);
registerArrayJoinFunctions<Json>(prefix);

registerFunction<ArrayAverageFunction, double, Array<double>>(
{prefix + "array_average"});
Expand Down
112 changes: 110 additions & 2 deletions velox/functions/prestosql/tests/ArrayJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <optional>
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/JsonType.h"

using namespace facebook::velox;
using namespace facebook::velox::test;
Expand All @@ -39,12 +40,16 @@ class ArrayJoinTest : public FunctionBaseTest {
std::vector<std::optional<T>> array,
const StringView& delimiter,
const StringView& expected,
bool isDate = false) {
bool isDate = false,
bool isJson = false) {
auto arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array});
if (isDate) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(DATE()));
} else if (isJson) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(JSON()));
}
auto delimiterVector = makeFlatVector<StringView>({delimiter});
auto expectedVector = makeFlatVector<StringView>({expected});
Expand All @@ -58,12 +63,16 @@ class ArrayJoinTest : public FunctionBaseTest {
const StringView& delimiter,
const StringView& replacement,
const StringView& expected,
bool isDate = false) {
bool isDate = false,
bool isJson = false) {
auto arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array});
if (isDate) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(DATE()));
} else if (isJson) {
arrayVector = makeNullableArrayVector(
std::vector<std::vector<std::optional<T>>>{array}, ARRAY(JSON()));
}
auto delimiterVector = makeFlatVector<StringView>({delimiter});
auto replacementVector = makeFlatVector<StringView>({replacement});
Expand Down Expand Up @@ -160,4 +169,103 @@ TEST_F(ArrayJoinTest, dateTest) {
true);
}

// Exercises array_join over ARRAY(JSON) inputs. Each case is run twice:
// once without a null replacement and once with "0" as the replacement,
// always using ", " as the delimiter.
TEST_F(ArrayJoinTest, jsonTest) {
  setLegacyCast(false);

  using JsonElements = std::vector<std::optional<StringView>>;

  // Runs both helper variants on 'elements' with isDate = false and
  // isJson = true, comparing against the expected joined strings.
  const auto verify = [this](
                          const JsonElements& elements,
                          const StringView& joined,
                          const StringView& joinedWithReplacement) {
    testArrayJoinNoReplacement<StringView>(
        elements, ", ", joined, false, true);
    testArrayJoinReplacement<StringView>(
        elements, ", ", "0", joinedWithReplacement, false, true);
  };

  // A JSON object, a null and an unquoted token.
  verify(
      {
          R"({"one":1,"two":2})",
          std::nullopt,
          R"(secondElement)",
      },
      R"({"one":1,"two":2}, secondElement)",
      R"({"one":1,"two":2}, 0, secondElement)");

  // A single quoted JSON string: the enclosing quotes are not expected in
  // the output.
  verify({R"("one element")"}, "one element", "one element");

  // Only nulls.
  verify({std::nullopt, std::nullopt}, "", "0, 0");

  // JSON strings with special characters
  verify(
      {
          R"({"key": "value\with\backslash"})",
          std::nullopt,
          R"('value\with\backslash')",
      },
      R"({"key": "value\with\backslash"}, 'value\with\backslash')",
      R"({"key": "value\with\backslash"}, 0, 'value\with\backslash')");

  // Escape sequences inside a quoted string are expected to come out as
  // written, without being decoded.
  verify(
      {
          R"({"key": "value\nwith\nnewline"})",
          std::nullopt,
          R"("value\nwith\nnewline")",
      },
      R"({"key": "value\nwith\nnewline"}, value\nwith\nnewline)",
      R"({"key": "value\nwith\nnewline"}, 0, value\nwith\nnewline)");

  verify(
      {
          R"({"key": "value with \u00A9 and \u20AC"})",
          std::nullopt,
          R"("value with \u00A9 and \u20AC")",
      },
      R"({"key": "value with \u00A9 and \u20AC"}, value with \u00A9 and \u20AC)",
      R"({"key": "value with \u00A9 and \u20AC"}, 0, value with \u00A9 and \u20AC)");

  verify(
      {
          R"({"key": "!@#$%^&*()_+-={}:<>?,./~`"})",
          std::nullopt,
          R"("!@#$%^&*()_+-={}:<>?,./~`")",
      },
      R"({"key": "!@#$%^&*()_+-={}:<>?,./~`"}, !@#$%^&*()_+-={}:<>?,./~`)",
      R"({"key": "!@#$%^&*()_+-={}:<>?,./~`"}, 0, !@#$%^&*()_+-={}:<>?,./~`)");
}
} // namespace

0 comments on commit e353b04

Please sign in to comment.