Skip to content

Commit

Permalink
Add support for custom comparison in map subscript and Presto's eleme…
Browse files Browse the repository at this point in the history
…nt_at UDF (facebookincubator#11239)

Summary:
Pull Request resolved: facebookincubator#11239

Update SubscriptUtil to work with types that provide custom comparison. We can
reuse the implementation for complex types, since that just uses the compare
function provided by the Vector. With
facebookincubator#11022 this just invokes the Type's
custom implementation.

Since they share the same underlying code this makes both map subscript and
the element_at UDF work with custom comparison.

Reviewed By: xiaoxmeng

Differential Revision: D64256363

fbshipit-source-id: ffd3350f1f00647bc2d397c6ea46fc1715a699f8
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 11, 2024
1 parent 5bedca0 commit 0e2e2b0
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 1 deletion.
8 changes: 7 additions & 1 deletion velox/functions/lib/SubscriptUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ VectorPtr MapSubscript::applyMap(
VELOX_CHECK(mapArg->type()->childAt(0)->equivalent(*indexArg->type()));

bool triggerCaching = shouldTriggerCaching(mapArg);
if (indexArg->type()->isPrimitiveType()) {
if (indexArg->type()->isPrimitiveType() &&
!indexArg->type()->providesCustomComparison()) {
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
applyMapTyped,
indexArg->typeKind(),
Expand All @@ -324,6 +325,11 @@ VectorPtr MapSubscript::applyMap(
indexArg,
context);
} else {
// We use applyMapComplexType when the key type is complex, but also when it
// provides custom comparison operators, because the main difference between
// applyMapComplexType and applyMapTyped is that applyMapComplexType calls the
// Vector's equalValueAt method, which calls the Type's custom comparison
// operator internally.
return applyMapComplexType(
rows, mapArg, indexArg, context, triggerCaching, lookupTable_);
}
Expand Down
98 changes: 98 additions & 0 deletions velox/functions/prestosql/tests/ElementAtTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "velox/expression/Expr.h"
#include "velox/functions/lib/SubscriptUtil.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
#include "velox/vector/BaseVector.h"
#include "velox/vector/SelectivityVector.h"

Expand Down Expand Up @@ -1329,3 +1330,100 @@ TEST_F(ElementAtTest, testCachingOptimizationComplexKey) {
checkStatus(false, true, nullptr);
test::assertEqualVectors(resultWithMoreVectors, resultWithMoreVectors1);
}

// Exercises element_at on maps keyed by TIMESTAMP WITH TIME ZONE, first with
// the key as a bare scalar and then with the key wrapped inside a ROW. The
// expectations below require a lookup to hit whenever the first pack()
// argument equals a stored key's, even if the second argument differs.
TEST_F(ElementAtTest, timestampWithTimeZone) {
  // Evaluates element_at over a two-column input: map column, then key column.
  auto lookUp = [&](const VectorPtr& mapColumn, const VectorPtr& keyColumn) {
    const auto input = makeRowVector({mapColumn, keyColumn});
    return evaluate("element_at(C0, C1)", input);
  };

  const auto mapValues = makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6});
  const auto timestampKeys = makeFlatVector<int64_t>(
      {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)},
      TIMESTAMP_WITH_TIME_ZONE());

  // Scalar keys.
  // NOTE(review): offsets {0, 3} presumably split the six entries into two map
  // rows of three entries each - confirm against makeMapVector.
  const auto scalarKeyMap = makeMapVector({0, 3}, timestampKeys, mapValues);
  const VectorPtr scalarExpected =
      makeNullableFlatVector<int32_t>({3, std::nullopt});

  // Search keys whose pack() arguments are identical to a stored key (first
  // row) or match no stored key at all (second row).
  const auto exactSearch = makeFlatVector<int64_t>(
      {pack(3, 3), pack(7, 7)}, TIMESTAMP_WITH_TIME_ZONE());
  test::assertEqualVectors(scalarExpected, lookUp(scalarKeyMap, exactSearch));

  // Same expectation when only the first pack() argument lines up with a
  // stored key.
  const auto differingSecondArgSearch = makeFlatVector<int64_t>(
      {pack(3, 10), pack(8, 5)}, TIMESTAMP_WITH_TIME_ZONE());
  test::assertEqualVectors(
      scalarExpected, lookUp(scalarKeyMap, differingSecondArgSearch));

  // Keys embedded in a complex (ROW) type.
  const auto wrappedKeys = makeRowVector({timestampKeys});
  const auto rowKeyMap = makeMapVector({0, 3}, wrappedKeys, mapValues);
  const auto wrappedSearch = makeRowVector({makeFlatVector<int64_t>(
      {pack(-1, 1), pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())});
  const VectorPtr rowExpected =
      makeNullableFlatVector<int32_t>({std::nullopt, 5});
  test::assertEqualVectors(rowExpected, lookUp(rowKeyMap, wrappedSearch));
}

// Drives MapSubscript directly (constructed with `true`, and expected to
// report caching as enabled) using TIMESTAMP WITH TIME ZONE keys, once as
// scalar keys and once wrapped in a ROW. After every call it checks the
// result and the caching state the object reports.
TEST_F(ElementAtTest, timestampWithTimeZoneWithCaching) {
  // Applies the subscript three times to the same arguments and verifies the
  // output plus the caching state after each application.
  auto runWithCaching = [&](std::vector<VectorPtr>&& args,
                            const VectorPtr& expected) {
    exec::ExprSet exprSet({}, &execCtx_);
    const auto emptyInput = makeRowVector({});
    exec::EvalCtx evalCtx(&execCtx_, &exprSet, emptyInput.get());

    const SelectivityVector rows(1);

    facebook::velox::functions::detail::MapSubscript subscript(true);

    // Asserts on the three pieces of caching state MapSubscript exposes.
    auto expectState = [&](bool wantCachingEnabled,
                           bool wantNoLookupTable,
                           const VectorPtr& wantFirstSeen) {
      EXPECT_EQ(wantCachingEnabled, subscript.cachingEnabled());
      EXPECT_EQ(wantFirstSeen, subscript.firstSeenMap());
      const bool lookupTableIsNull = nullptr == subscript.lookupTable();
      EXPECT_EQ(wantNoLookupTable, lookupTableIsNull);
    };

    // Before any call: caching on, nothing seen, nothing materialized.
    expectState(true, true, nullptr);

    // First call: the map is remembered, but no lookup table is built because
    // the input has only been seen once.
    test::assertEqualVectors(expected, subscript.applyMap(rows, args, evalCtx));
    expectState(true, true, args[0]);

    // Second and third calls with the same map: the lookup table exists after
    // the second call and is still there after the third.
    for (int repeat = 0; repeat < 2; ++repeat) {
      test::assertEqualVectors(
          expected, subscript.applyMap(rows, args, evalCtx));
      expectState(true, false, args[0]);
    }
  };

  const auto mapValues = makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6});
  const auto timestampKeys = makeFlatVector<int64_t>(
      {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)},
      TIMESTAMP_WITH_TIME_ZONE());

  // Scalar keys with caching: the search key shares only its first pack()
  // argument with the stored key.
  const auto scalarKeyMap = makeMapVector({0}, timestampKeys, mapValues);
  const VectorPtr scalarSearch = makeFlatVector(
      std::vector<int64_t>{pack(3, 5)}, TIMESTAMP_WITH_TIME_ZONE());
  runWithCaching({scalarKeyMap, scalarSearch}, makeConstant<int32_t>(3, 1));

  // Keys embedded in a complex (ROW) type, with caching.
  const auto wrappedKeys = makeRowVector({timestampKeys});
  const auto rowKeyMap = makeMapVector({0}, wrappedKeys, mapValues);
  const VectorPtr wrappedSearch = makeRowVector({makeFlatVector(
      std::vector<int64_t>{pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())});
  runWithCaching({rowKeyMap, wrappedSearch}, makeConstant<int32_t>(5, 1));
}

0 comments on commit 0e2e2b0

Please sign in to comment.