From 0e2e2b0d9f475e088b236a90b479bb4042f641cb Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Fri, 11 Oct 2024 14:40:36 -0700 Subject: [PATCH] Add support for custom comparison in map subscript and Presto's element_at UDF (#11239) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11239 Update SubscriptUtil to work with types that provide custom comparison. We can reuse the implementation for complex types, since that just uses the compare function provided by the Vector. With https://github.com/facebookincubator/velox/pull/11022 this just invokes the Type's custom implementation. Since they share the same underlying code this makes both map subscript and the element_at UDF work with custom comparison. Reviewed By: xiaoxmeng Differential Revision: D64256363 fbshipit-source-id: ffd3350f1f00647bc2d397c6ea46fc1715a699f8 --- velox/functions/lib/SubscriptUtil.cpp | 8 +- .../prestosql/tests/ElementAtTest.cpp | 98 +++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/velox/functions/lib/SubscriptUtil.cpp b/velox/functions/lib/SubscriptUtil.cpp index 51931384855f..c01dc64d3cab 100644 --- a/velox/functions/lib/SubscriptUtil.cpp +++ b/velox/functions/lib/SubscriptUtil.cpp @@ -313,7 +313,8 @@ VectorPtr MapSubscript::applyMap( VELOX_CHECK(mapArg->type()->childAt(0)->equivalent(*indexArg->type())); bool triggerCaching = shouldTriggerCaching(mapArg); - if (indexArg->type()->isPrimitiveType()) { + if (indexArg->type()->isPrimitiveType() && + !indexArg->type()->providesCustomComparison()) { return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( applyMapTyped, indexArg->typeKind(), @@ -324,6 +325,11 @@ VectorPtr MapSubscript::applyMap( indexArg, context); } else { + // We use applyMapComplexType when the key type is complex, but also when it + // provides custom comparison operators because the main difference between + // applyMapComplexType and applyMapTyped is that the former calls the + // Vector's equalValueAt method, 
which calls the Type's custom comparison + // operator internally. return applyMapComplexType( rows, mapArg, indexArg, context, triggerCaching, lookupTable_); } diff --git a/velox/functions/prestosql/tests/ElementAtTest.cpp b/velox/functions/prestosql/tests/ElementAtTest.cpp index 1a7a6b8d003d..64b7f7085543 100644 --- a/velox/functions/prestosql/tests/ElementAtTest.cpp +++ b/velox/functions/prestosql/tests/ElementAtTest.cpp @@ -20,6 +20,7 @@ #include "velox/expression/Expr.h" #include "velox/functions/lib/SubscriptUtil.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/vector/BaseVector.h" #include "velox/vector/SelectivityVector.h" @@ -1329,3 +1330,100 @@ TEST_F(ElementAtTest, testCachingOptimizationComplexKey) { checkStatus(false, true, nullptr); test::assertEqualVectors(resultWithMoreVectors, resultWithMoreVectors1); } + +TEST_F(ElementAtTest, timestampWithTimeZone) { + const auto values = makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}); + VectorPtr expected = makeNullableFlatVector<int32_t>({3, std::nullopt}); + + auto elementAt = [&](const VectorPtr& map, const VectorPtr& search) { + return evaluate("element_at(C0, C1)", makeRowVector({map, search})); + }; + + // Test elementAt with scalar values. + const auto keys = makeFlatVector<int64_t>( + {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)}, + TIMESTAMP_WITH_TIME_ZONE()); + const auto mapVector = makeMapVector({0, 3}, keys, values); + test::assertEqualVectors( + expected, + elementAt( + mapVector, + makeFlatVector<int64_t>( + {pack(3, 3), pack(7, 7)}, TIMESTAMP_WITH_TIME_ZONE()))); + test::assertEqualVectors( + expected, + elementAt( + mapVector, + makeFlatVector<int64_t>( + {pack(3, 10), pack(8, 5)}, TIMESTAMP_WITH_TIME_ZONE()))); + + // Test elementAt with TimestampWithTimeZone values embedded in a complex + // type. 
+ const auto rowKeys = makeRowVector({keys}); + const auto mapOfRowKeys = makeMapVector({0, 3}, rowKeys, values); + const auto element = makeRowVector({makeFlatVector<int64_t>( + {pack(-1, 1), pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())}); + expected = makeNullableFlatVector<int32_t>({std::nullopt, 5}); + test::assertEqualVectors(expected, elementAt(mapOfRowKeys, element)); +} + +TEST_F(ElementAtTest, timestampWithTimeZoneWithCaching) { + auto testCaching = [&](std::vector<VectorPtr>&& args, + const VectorPtr& expected) { + exec::ExprSet exprSet({}, &execCtx_); + const auto inputs = makeRowVector({}); + exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputs.get()); + + const SelectivityVector rows(1); + + facebook::velox::functions::detail::MapSubscript mapSubscriptWithCaching( + true); + + auto checkStatus = [&](bool cachingEnabled, + bool materializedMapIsNull, + const VectorPtr& firstSeen) { + EXPECT_EQ(cachingEnabled, mapSubscriptWithCaching.cachingEnabled()); + EXPECT_EQ(firstSeen, mapSubscriptWithCaching.firstSeenMap()); + EXPECT_EQ( + materializedMapIsNull, + nullptr == mapSubscriptWithCaching.lookupTable()); + }; + + // Initial state. + checkStatus(true, true, nullptr); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // Nothing has been materialized yet since the input is seen only once. + checkStatus(true, true, args[0]); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // The argument from the previous call should be cached. + checkStatus(true, false, args[0]); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // The map should still be cached because we called it with the same + // argument. + checkStatus(true, false, args[0]); + }; + + // Test elementAt with scalar values and caching. 
+ const auto keys = makeFlatVector<int64_t>( + {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)}, + TIMESTAMP_WITH_TIME_ZONE()); + const auto values = makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}); + const auto inputMap = makeMapVector({0}, keys, values); + VectorPtr lookup = makeFlatVector<int64_t>( + std::vector<int64_t>{pack(3, 5)}, TIMESTAMP_WITH_TIME_ZONE()); + testCaching({inputMap, lookup}, makeConstant(3, 1)); + + // Test elementAt with TimestampWithTimeZone values embedded in a complex type + // with caching. + const auto rowKeys = makeRowVector({keys}); + const auto mapOfRowKeys = makeMapVector({0}, rowKeys, values); + lookup = makeRowVector({makeFlatVector<int64_t>( + std::vector<int64_t>{pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())}); + testCaching({mapOfRowKeys, lookup}, makeConstant(5, 1)); +}