diff --git a/velox/functions/lib/SubscriptUtil.cpp b/velox/functions/lib/SubscriptUtil.cpp index 51931384855f..c01dc64d3cab 100644 --- a/velox/functions/lib/SubscriptUtil.cpp +++ b/velox/functions/lib/SubscriptUtil.cpp @@ -313,7 +313,8 @@ VectorPtr MapSubscript::applyMap( VELOX_CHECK(mapArg->type()->childAt(0)->equivalent(*indexArg->type())); bool triggerCaching = shouldTriggerCaching(mapArg); - if (indexArg->type()->isPrimitiveType()) { + if (indexArg->type()->isPrimitiveType() && + !indexArg->type()->providesCustomComparison()) { return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( applyMapTyped, indexArg->typeKind(), @@ -324,6 +325,11 @@ VectorPtr MapSubscript::applyMap( indexArg, context); } else { + // We use applyMapComplexType when the key type is complex, but also when it + // provides custom comparison operators because the main difference between + // applyMapComplexType and applyTyped is that applyMapComplexType calls the + // Vector's equalValueAt method, which calls the Types custom comparison + // operator internally. return applyMapComplexType( rows, mapArg, indexArg, context, triggerCaching, lookupTable_); } diff --git a/velox/functions/prestosql/tests/ElementAtTest.cpp b/velox/functions/prestosql/tests/ElementAtTest.cpp index 1a7a6b8d003d..64b7f7085543 100644 --- a/velox/functions/prestosql/tests/ElementAtTest.cpp +++ b/velox/functions/prestosql/tests/ElementAtTest.cpp @@ -20,6 +20,7 @@ #include "velox/expression/Expr.h" #include "velox/functions/lib/SubscriptUtil.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/vector/BaseVector.h" #include "velox/vector/SelectivityVector.h" @@ -1329,3 +1330,100 @@ TEST_F(ElementAtTest, testCachingOptimizationComplexKey) { checkStatus(false, true, nullptr); test::assertEqualVectors(resultWithMoreVectors, resultWithMoreVectors1); } + +TEST_F(ElementAtTest, timestampWithTimeZone) { + const auto values = makeFlatVector({1, 2, 3, 4, 5, 6}); + VectorPtr expected = makeNullableFlatVector({3, std::nullopt}); + + auto elementAt = [&](const VectorPtr& map, const VectorPtr& search) { + return evaluate("element_at(C0, C1)", makeRowVector({map, search})); + }; + + // Test elementAt with scalar values. + const auto keys = makeFlatVector( + {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)}, + TIMESTAMP_WITH_TIME_ZONE()); + const auto mapVector = makeMapVector({0, 3}, keys, values); + test::assertEqualVectors( + expected, + elementAt( + mapVector, + makeFlatVector( + {pack(3, 3), pack(7, 7)}, TIMESTAMP_WITH_TIME_ZONE()))); + test::assertEqualVectors( + expected, + elementAt( + mapVector, + makeFlatVector( + {pack(3, 10), pack(8, 5)}, TIMESTAMP_WITH_TIME_ZONE()))); + + // Test elementAt with TimestampWithTimeZone values embedded in a complex + // type. + const auto rowKeys = makeRowVector({keys}); + const auto mapOfRowKeys = makeMapVector({0, 3}, rowKeys, values); + const auto element = makeRowVector({makeFlatVector( + {pack(-1, 1), pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())}); + expected = makeNullableFlatVector({std::nullopt, 5}); + test::assertEqualVectors(expected, elementAt(mapOfRowKeys, element)); +} + +TEST_F(ElementAtTest, timestampWithTimeZoneWithCaching) { + auto testCaching = [&](std::vector&& args, + const VectorPtr& expected) { + exec::ExprSet exprSet({}, &execCtx_); + const auto inputs = makeRowVector({}); + exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputs.get()); + + const SelectivityVector rows(1); + + facebook::velox::functions::detail::MapSubscript mapSubscriptWithCaching( + true); + + auto checkStatus = [&](bool cachingEnabled, + bool materializedMapIsNull, + const VectorPtr& firstSeen) { + EXPECT_EQ(cachingEnabled, mapSubscriptWithCaching.cachingEnabled()); + EXPECT_EQ(firstSeen, mapSubscriptWithCaching.firstSeenMap()); + EXPECT_EQ( + materializedMapIsNull, + nullptr == mapSubscriptWithCaching.lookupTable()); + }; + + // Initial state. + checkStatus(true, true, nullptr); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // Nothing has been materialized yet since the input is seen only once. + checkStatus(true, true, args[0]); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // The argument from the previous call should be cached. + checkStatus(true, false, args[0]); + + test::assertEqualVectors( + expected, mapSubscriptWithCaching.applyMap(rows, args, evalCtx)); + // The map should still be cached because we called it with the same + // argument. + checkStatus(true, false, args[0]); + }; + + // Test elementAt with scalar values and caching. + const auto keys = makeFlatVector( + {pack(1, 1), pack(2, 2), pack(3, 3), pack(4, 4), pack(5, 5), pack(6, 6)}, + TIMESTAMP_WITH_TIME_ZONE()); + const auto values = makeFlatVector({1, 2, 3, 4, 5, 6}); + const auto inputMap = makeMapVector({0}, keys, values); + VectorPtr lookup = makeFlatVector( + std::vector{pack(3, 5)}, TIMESTAMP_WITH_TIME_ZONE()); + testCaching({inputMap, lookup}, makeConstant(3, 1)); + + // Test elementAt with TimestampWithTimeZone values embedded in a complex type + // with caching. + const auto rowKeys = makeRowVector({keys}); + const auto mapOfRowKeys = makeMapVector({0}, rowKeys, values); + lookup = makeRowVector({makeFlatVector( + std::vector{pack(5, 10)}, TIMESTAMP_WITH_TIME_ZONE())}); + testCaching({mapOfRowKeys, lookup}, makeConstant(5, 1)); +}