From e5671c09c03ee3ae55a479c9ce84c76e1c256ab2 Mon Sep 17 00:00:00 2001 From: zhli1142015 Date: Wed, 24 Jul 2024 15:36:35 -0700 Subject: [PATCH] Add min, max Spark aggregate functions (#9868) Summary: There are two semantic differences between Presto and Spark. 1. Nested NULLs are compared as values in Spark and as "unknown value" in Presto. 2. The timestamp type represents a time instant in microsecond precision in Spark, but millisecond precision in Presto. Therefore, we need to implement min and max functions for Spark. In this PR, 1. Move Presto `min` and `max` aggregation function implements to lib folder. 2. Add `getMinFunctionFactory` and `getMaxFunctionFactory` which allow callers to register max & min functions with different behaviors. Pull Request resolved: https://github.com/facebookincubator/velox/pull/9868 Reviewed By: mbasmanova Differential Revision: D60051468 Pulled By: kevinwilfong fbshipit-source-id: 1f056420d6909174a35d336e4e1b413a87ef7665 --- velox/docs/functions/spark/aggregate.rst | 10 + velox/functions/lib/aggregates/CMakeLists.txt | 2 + .../{prestosql => lib}/aggregates/Compare.cpp | 13 +- .../{prestosql => lib}/aggregates/Compare.h | 14 +- .../lib/aggregates/MinMaxAggregateBase.cpp | 641 ++++++++++++++++++ .../lib/aggregates/MinMaxAggregateBase.h | 38 ++ .../prestosql/aggregates/CMakeLists.txt | 1 - .../prestosql/aggregates/MinMaxAggregates.cpp | 589 +--------------- .../aggregates/MinMaxByAggregates.cpp | 14 +- .../sparksql/aggregates/CMakeLists.txt | 1 + .../sparksql/aggregates/MinMaxAggregate.cpp | 92 +++ .../sparksql/aggregates/Register.cpp | 5 + .../sparksql/aggregates/tests/CMakeLists.txt | 2 + .../tests/MinMaxAggregationTest.cpp | 300 ++++++++ .../fuzzer/SparkAggregationFuzzerTest.cpp | 4 + velox/type/Timestamp.h | 14 + 16 files changed, 1157 insertions(+), 583 deletions(-) rename velox/functions/{prestosql => lib}/aggregates/Compare.cpp (78%) rename velox/functions/{prestosql => lib}/aggregates/Compare.h (70%) create mode 100644 velox/functions/lib/aggregates/MinMaxAggregateBase.cpp create mode 100644 velox/functions/lib/aggregates/MinMaxAggregateBase.h create mode 100644 velox/functions/sparksql/aggregates/MinMaxAggregate.cpp create mode 100644 velox/functions/sparksql/aggregates/tests/MinMaxAggregationTest.cpp diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index 4056db9659fb..a4ad0ff5dd61 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -80,6 +80,11 @@ General Aggregate Functions Returns the last non-null value of `x`. +.. spark:function:: max(x) -> [same as x] + + Returns the maximum value of ``x``. + ``x`` must be an orderable type. + .. spark:function:: max_by(x, y) -> [same as x] Returns the value of `x` associated with the maximum value of `y`. @@ -97,6 +102,11 @@ General Aggregate Functions Returns c +.. spark:function:: min(x) -> [same as x] + + Returns the minimum value of ``x``. + ``x`` must be an orderable type. + .. spark:function:: min_by(x, y) -> [same as x] Returns the value of `x` associated with the minimum value of `y`. diff --git a/velox/functions/lib/aggregates/CMakeLists.txt b/velox/functions/lib/aggregates/CMakeLists.txt index e3d04ab8333d..4322a1ac5720 100644 --- a/velox/functions/lib/aggregates/CMakeLists.txt +++ b/velox/functions/lib/aggregates/CMakeLists.txt @@ -16,6 +16,8 @@ velox_add_library( velox_functions_aggregates AverageAggregateBase.cpp CentralMomentsAggregatesBase.cpp + Compare.cpp + MinMaxAggregateBase.cpp SingleValueAccumulator.cpp ValueList.cpp ValueSet.cpp) diff --git a/velox/functions/prestosql/aggregates/Compare.cpp b/velox/functions/lib/aggregates/Compare.cpp similarity index 78% rename from velox/functions/prestosql/aggregates/Compare.cpp rename to velox/functions/lib/aggregates/Compare.cpp index 3f784f453147..5595be2105c1 100644 --- a/velox/functions/prestosql/aggregates/Compare.cpp +++ b/velox/functions/lib/aggregates/Compare.cpp @@ -14,21 +14,20 @@ * limitations under the License. */ -#include "velox/functions/prestosql/aggregates/Compare.h" +#include "velox/functions/lib/aggregates/Compare.h" -using namespace facebook::velox::functions::aggregate; - -namespace facebook::velox::aggregate::prestosql { +namespace facebook::velox::functions::aggregate { int32_t compare( const SingleValueAccumulator* accumulator, const DecodedVector& decoded, - vector_size_t index) { + vector_size_t index, + CompareFlags::NullHandlingMode nullHandlingMode) { static const CompareFlags kCompareFlags{ true, // nullsFirst true, // ascending false, // equalsOnly - CompareFlags::NullHandlingMode::kNullAsIndeterminate}; + nullHandlingMode}; auto result = accumulator->compare(decoded, index, kCompareFlags); VELOX_USER_CHECK( @@ -38,4 +37,4 @@ int32_t compare( mapTypeKindToName(decoded.base()->typeKind()))); return result.value(); } -} // namespace facebook::velox::aggregate::prestosql +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/prestosql/aggregates/Compare.h b/velox/functions/lib/aggregates/Compare.h similarity index 70% rename from velox/functions/prestosql/aggregates/Compare.h rename to velox/functions/lib/aggregates/Compare.h index 8b2d7d52c5fc..7b648d09dd08 100644 --- a/velox/functions/prestosql/aggregates/Compare.h +++ b/velox/functions/lib/aggregates/Compare.h @@ -20,17 +20,21 @@ #include "velox/functions/lib/aggregates/SingleValueAccumulator.h" #include "velox/vector/DecodedVector.h" -namespace facebook::velox::aggregate::prestosql { +namespace facebook::velox::functions::aggregate { /// Compare the new value of the DecodedVector at the given index with the value /// stored in the SingleValueAccumulator. Returns 0 if stored and new values are /// equal; <0 if stored value is less then new value; >0 if stored value is /// greater than new value. /// -/// The default nullHandlingMode in Presto is StopAtNull so it will throw an -/// exception when complex type values contain nulls. +/// If nullHandlingMode is NullAsValue, nested nulls are handled as value. If +/// nullHandlingMode is StopAtNull, it will throw an exception when complex +/// type values contain nulls. +/// Note, The default nullHandlingMode in Presto is StopAtNull while the +/// default nullHandlingMode is NullAsValue in Spark. int32_t compare( const velox::functions::aggregate::SingleValueAccumulator* accumulator, const DecodedVector& decoded, - vector_size_t index); -} // namespace facebook::velox::aggregate::prestosql + vector_size_t index, + CompareFlags::NullHandlingMode nullHandlingMode); +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/lib/aggregates/MinMaxAggregateBase.cpp b/velox/functions/lib/aggregates/MinMaxAggregateBase.cpp new file mode 100644 index 000000000000..ef98d0ab0977 --- /dev/null +++ b/velox/functions/lib/aggregates/MinMaxAggregateBase.cpp @@ -0,0 +1,641 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/lib/aggregates/MinMaxAggregateBase.h" + +#include +#include "velox/exec/AggregationHook.h" +#include "velox/functions/lib/CheckNestedNulls.h" +#include "velox/functions/lib/aggregates/Compare.h" +#include "velox/functions/lib/aggregates/SimpleNumericAggregate.h" +#include "velox/functions/lib/aggregates/SingleValueAccumulator.h" +#include "velox/type/FloatingPointUtil.h" + +namespace facebook::velox::functions::aggregate { + +namespace { + +template +struct MinMaxTrait : public std::numeric_limits {}; + +template +class MinMaxAggregate : public SimpleNumericAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinMaxAggregate(TypePtr resultType, TimestampPrecision precision) + : BaseAggregate(resultType), timestampPrecision_(precision) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(T); + } + + int32_t accumulatorAlignmentSize() const override { + if constexpr (std::is_same_v) { + // Override 'accumulatorAlignmentSize' for UnscaledLongDecimal values as + // it uses int128_t type. Some CPUs don't support misaligned access to + // int128_t type. + return static_cast(sizeof(int128_t)); + } else { + return 1; + } + } + + bool supportsToIntermediate() const override { + return true; + } + + void toIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const override { + this->singleInputAsIntermediate(rows, args, result); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + if constexpr (std::is_same_v) { + // Truncate timestamps to corresponding precision. + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + auto ts = + *BaseAggregate::Aggregate::template value(group); + return Timestamp::truncate(ts, timestampPrecision_); + }); + } else { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } + + private: + const TimestampPrecision timestampPrecision_; +}; + +template +class MaxAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MaxAggregate( + TypePtr resultType, + TimestampPrecision precision = TimestampPrecision::kMilliseconds) + : MinMaxAggregate(resultType, precision) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + // Re-enable pushdown for TIMESTAMP after + // https://github.com/facebookincubator/velox/issues/6297 is fixed. + if (args[0]->typeKind() == TypeKind::TIMESTAMP) { + mayPushdown = false; + } + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, rows, args[0], updateGroup, mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + updateGroup, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + protected: + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + static inline void updateGroup(T& result, T value) { + if constexpr (std::is_floating_point_v) { + if (util::floating_point::NaNAwareLessThan{}(result, value)) { + result = value; + } + } else { + if (result < value) { + result = value; + } + } + } + + private: + static const T kInitialValue_; +}; + +template +const T MaxAggregate::kInitialValue_ = MinMaxTrait::lowest(); + +// Negative INF is the smallest value of floating point type. +template <> +const float MaxAggregate::kInitialValue_ = + -1 * MinMaxTrait::infinity(); + +template <> +const double MaxAggregate::kInitialValue_ = + -1 * MinMaxTrait::infinity(); + +template +class MinAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinAggregate( + TypePtr resultType, + TimestampPrecision precision = TimestampPrecision::kMilliseconds) + : MinMaxAggregate(resultType, precision) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + // Re-enable pushdown for TIMESTAMP after + // https://github.com/facebookincubator/velox/issues/6297 is fixed. + if (args[0]->typeKind() == TypeKind::TIMESTAMP) { + mayPushdown = false; + } + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, rows, args[0], updateGroup, mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + updateGroup, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + protected: + static inline void updateGroup(T& result, T value) { + if constexpr (std::is_floating_point_v) { + if (util::floating_point::NaNAwareGreaterThan{}(result, value)) { + result = value; + } + } else { + if (result > value) { + result = value; + } + } + } + + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + private: + static const T kInitialValue_; +}; + +template +const T MinAggregate::kInitialValue_ = MinMaxTrait::max(); + +// In velox, NaN is considered larger than infinity for floating point types. +template <> +const float MinAggregate::kInitialValue_ = + MinMaxTrait::quiet_NaN(); + +template <> +const double MinAggregate::kInitialValue_ = + MinMaxTrait::quiet_NaN(); + +class NonNumericMinMaxAggregateBase : public exec::Aggregate { + public: + explicit NonNumericMinMaxAggregateBase( + const TypePtr& resultType, + bool throwOnNestedNulls) + : exec::Aggregate(resultType), throwOnNestedNulls_(throwOnNestedNulls) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(SingleValueAccumulator); + } + + bool supportsToIntermediate() const override { + return true; + } + + void toIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const override { + const auto& input = args[0]; + + if (throwOnNestedNulls_) { + DecodedVector decoded(*input, rows, true); + auto indices = decoded.indices(); + rows.applyToSelected([&](vector_size_t i) { + velox::functions::checkNestedNulls( + decoded, indices, i, throwOnNestedNulls_); + }); + } + + if (rows.isAllSelected()) { + result = input; + return; + } + + auto* pool = allocator_->pool(); + + // Set result to NULL for rows that are masked out. + BufferPtr nulls = allocateNulls(rows.size(), pool, bits::kNull); + rows.clearNulls(nulls); + + BufferPtr indices = allocateIndices(rows.size(), pool); + auto* rawIndices = indices->asMutable(); + std::iota(rawIndices, rawIndices + rows.size(), 0); + + result = BaseVector::wrapInDictionary(nulls, indices, rows.size(), input); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + (*result)->resize(numGroups); + + uint64_t* rawNulls = nullptr; + if ((*result)->mayHaveNulls()) { + BufferPtr& nulls = (*result)->mutableNulls((*result)->size()); + rawNulls = nulls->asMutable(); + } + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + auto accumulator = value(group); + if (!accumulator->hasValue()) { + (*result)->setNull(i, true); + } else { + if (rawNulls) { + bits::clearBit(rawNulls, i); + } + accumulator->read(*result, i); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + // partial and final aggregations are the same + extractValues(groups, numGroups, result); + } + + protected: + template < + typename TCompareTest, + CompareFlags::NullHandlingMode nullHandlingMode> + void doUpdate( + char** groups, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping() && decoded.isNullAt(0)) { + // nothing to do; all values are nulls + return; + } + + rows.applyToSelected([&](vector_size_t i) { + if (velox::functions::checkNestedNulls( + decoded, indices, i, throwOnNestedNulls_)) { + return; + } + + auto accumulator = value(groups[i]); + if (!accumulator->hasValue() || + compareTest(compare(accumulator, decoded, i, nullHandlingMode))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } + + template < + typename TCompareTest, + CompareFlags::NullHandlingMode nullHandlingMode> + void doUpdateSingleGroup( + char* group, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping()) { + if (velox::functions::checkNestedNulls( + decoded, indices, 0, throwOnNestedNulls_)) { + return; + } + + auto accumulator = value(group); + if (!accumulator->hasValue() || + compareTest(compare(accumulator, decoded, 0, nullHandlingMode))) { + accumulator->write(baseVector, indices[0], allocator_); + } + return; + } + + auto accumulator = value(group); + rows.applyToSelected([&](vector_size_t i) { + if (velox::functions::checkNestedNulls( + decoded, indices, i, throwOnNestedNulls_)) { + return; + } + if (!accumulator->hasValue() || + compareTest(compare(accumulator, decoded, i, nullHandlingMode))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } + + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) SingleValueAccumulator(); + } + } + + void destroyInternal(folly::Range groups) override { + for (auto group : groups) { + if (isInitialized(group)) { + value(group)->destroy(allocator_); + } + } + } + + private: + const bool throwOnNestedNulls_; +}; + +template +class NonNumericMaxAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMaxAggregate( + const TypePtr& resultType, + bool throwOnNestedNulls) + : NonNumericMinMaxAggregateBase(resultType, throwOnNestedNulls) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate, nullHandlingMode>( + groups, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup, nullHandlingMode>( + group, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +template +class NonNumericMinAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMinAggregate( + const TypePtr& resultType, + bool throwOnNestedNulls) + : NonNumericMinMaxAggregateBase(resultType, throwOnNestedNulls) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate, nullHandlingMode>( + groups, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup, nullHandlingMode>( + group, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +template < + template + class TNumeric, + template + typename TNonNumeric> +exec::AggregateFunctionFactory getMinMaxFunctionFactoryInternal( + const std::string& name, + CompareFlags::NullHandlingMode nullHandlingMode, + TimestampPrecision precision) { + auto factory = [name, nullHandlingMode, precision]( + core::AggregationNode::Step step, + std::vector argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + auto inputType = argTypes[0]; + switch (inputType->kind()) { + case TypeKind::BOOLEAN: + return std::make_unique>(resultType); + case TypeKind::TINYINT: + return std::make_unique>(resultType); + case TypeKind::SMALLINT: + return std::make_unique>(resultType); + case TypeKind::INTEGER: + return std::make_unique>(resultType); + case TypeKind::BIGINT: + return std::make_unique>(resultType); + case TypeKind::REAL: + return std::make_unique>(resultType); + case TypeKind::DOUBLE: + return std::make_unique>(resultType); + case TypeKind::TIMESTAMP: + return std::make_unique>(resultType, precision); + case TypeKind::HUGEINT: + return std::make_unique>(resultType); + case TypeKind::VARBINARY: + [[fallthrough]]; + case TypeKind::VARCHAR: + return std::make_unique< + TNonNumeric>( + inputType, false); + case TypeKind::ARRAY: + [[fallthrough]]; + case TypeKind::ROW: + if (nullHandlingMode == CompareFlags::NullHandlingMode::kNullAsValue) { + return std::make_unique< + TNonNumeric>( + inputType, false); + } else { + return std::make_unique>( + inputType, true); + } + case TypeKind::UNKNOWN: + return std::make_unique>(resultType); + default: + VELOX_UNREACHABLE( + "Unknown input type for {} aggregation {}", + name, + inputType->kindName()); + } + }; + return factory; +} + +} // namespace + +exec::AggregateFunctionFactory getMinFunctionFactory( + const std::string& name, + CompareFlags::NullHandlingMode nullHandlingMode, + TimestampPrecision precision) { + return getMinMaxFunctionFactoryInternal( + name, nullHandlingMode, precision); +} + +exec::AggregateFunctionFactory getMaxFunctionFactory( + const std::string& name, + CompareFlags::NullHandlingMode nullHandlingMode, + TimestampPrecision precision) { + return getMinMaxFunctionFactoryInternal( + name, nullHandlingMode, precision); +} +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/lib/aggregates/MinMaxAggregateBase.h b/velox/functions/lib/aggregates/MinMaxAggregateBase.h new file mode 100644 index 000000000000..d7524bda95bd --- /dev/null +++ b/velox/functions/lib/aggregates/MinMaxAggregateBase.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/Aggregate.h" + +namespace facebook::velox::functions::aggregate { + +/// Min and max functions in Presto and Spark have different semantics: +/// 1. Nested NULLs are compared as values in Spark and as "unknown value" in +/// Presto. +/// 2. The timestamp type represents a time instant in microsecond precision in +/// Spark, but millisecond precision in Presto. +/// Parameters 'nullHandlingMode' and 'precision' allow to register min and max +/// functions with different behaviors. +exec::AggregateFunctionFactory getMinFunctionFactory( + const std::string& name, + CompareFlags::NullHandlingMode nullHandlingMode, + TimestampPrecision precision); + +exec::AggregateFunctionFactory getMaxFunctionFactory( + const std::string& name, + CompareFlags::NullHandlingMode nullHandlingMode, + TimestampPrecision precision); +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index 49b680bb61b5..4e467bab1f66 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -24,7 +24,6 @@ velox_add_library( BitwiseXorAggregate.cpp BoolAggregates.cpp CentralMomentsAggregates.cpp - Compare.cpp CountAggregate.cpp CountIfAggregate.cpp CovarianceAggregates.cpp diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index 3b6c16e63746..871a62733d33 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -17,11 +17,10 @@ #include #include "velox/exec/Aggregate.h" #include "velox/exec/AggregationHook.h" -#include "velox/functions/lib/CheckNestedNulls.h" +#include "velox/functions/lib/aggregates/MinMaxAggregateBase.h" #include "velox/functions/lib/aggregates/SimpleNumericAggregate.h" #include "velox/functions/lib/aggregates/SingleValueAccumulator.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/functions/prestosql/aggregates/Compare.h" #include "velox/type/FloatingPointUtil.h" using namespace facebook::velox::functions::aggregate; @@ -30,519 +29,6 @@ namespace facebook::velox::aggregate::prestosql { namespace { -template -struct MinMaxTrait : public std::numeric_limits {}; - -template -class MinMaxAggregate : public SimpleNumericAggregate { - using BaseAggregate = SimpleNumericAggregate; - - public: - explicit MinMaxAggregate(TypePtr resultType) : BaseAggregate(resultType) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(T); - } - - int32_t accumulatorAlignmentSize() const override { - return 1; - } - - bool supportsToIntermediate() const override { - return true; - } - - void toIntermediate( - const SelectivityVector& rows, - std::vector& args, - VectorPtr& result) const override { - this->singleInputAsIntermediate(rows, args, result); - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - BaseAggregate::template doExtractValues( - groups, numGroups, result, [&](char* group) { - return *BaseAggregate::Aggregate::template value(group); - }); - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - BaseAggregate::template doExtractValues( - groups, numGroups, result, [&](char* group) { - return *BaseAggregate::Aggregate::template value(group); - }); - } -}; - -/// Override 'accumulatorAlignmentSize' for UnscaledLongDecimal values as it -/// uses int128_t type. Some CPUs don't support misaligned access to int128_t -/// type. -template <> -inline int32_t MinMaxAggregate::accumulatorAlignmentSize() const { - return static_cast(sizeof(int128_t)); -} - -// Truncate timestamps to milliseconds precision. -template <> -void MinMaxAggregate::extractValues( - char** groups, - int32_t numGroups, - VectorPtr* result) { - BaseAggregate::template doExtractValues( - groups, numGroups, result, [&](char* group) { - auto ts = *BaseAggregate::Aggregate::template value(group); - return Timestamp::fromMillis(ts.toMillis()); - }); -} - -template -class MaxAggregate : public MinMaxAggregate { - using BaseAggregate = SimpleNumericAggregate; - - public: - explicit MaxAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - // Re-enable pushdown for TIMESTAMP after - // https://github.com/facebookincubator/velox/issues/6297 is fixed. - if (args[0]->typeKind() == TypeKind::TIMESTAMP) { - mayPushdown = false; - } - if (mayPushdown && args[0]->isLazy()) { - BaseAggregate::template pushdown>( - groups, rows, args[0]); - return; - } - BaseAggregate::template updateGroups( - groups, rows, args[0], updateGroup, mayPushdown); - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addRawInput(groups, rows, args, mayPushdown); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - BaseAggregate::updateOneGroup( - group, - rows, - args[0], - updateGroup, - [](T& result, T value, int /* unused */) { result = value; }, - mayPushdown, - kInitialValue_); - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addSingleGroupRawInput(group, rows, args, mayPushdown); - } - - protected: - void initializeNewGroupsInternal( - char** groups, - folly::Range indices) override { - exec::Aggregate::setAllNulls(groups, indices); - for (auto i : indices) { - *exec::Aggregate::value(groups[i]) = kInitialValue_; - } - } - - static inline void updateGroup(T& result, T value) { - if constexpr (std::is_floating_point_v) { - if (util::floating_point::NaNAwareLessThan{}(result, value)) { - result = value; - } - } else { - if (result < value) { - result = value; - } - } - } - - private: - static const T kInitialValue_; -}; - -template -const T MaxAggregate::kInitialValue_ = MinMaxTrait::lowest(); - -// Negative INF is the smallest value of floating point type. -template <> -const float MaxAggregate::kInitialValue_ = - -1 * MinMaxTrait::infinity(); - -template <> -const double MaxAggregate::kInitialValue_ = - -1 * MinMaxTrait::infinity(); - -template -class MinAggregate : public MinMaxAggregate { - using BaseAggregate = SimpleNumericAggregate; - - public: - explicit MinAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - // Re-enable pushdown for TIMESTAMP after - // https://github.com/facebookincubator/velox/issues/6297 is fixed. - if (args[0]->typeKind() == TypeKind::TIMESTAMP) { - mayPushdown = false; - } - if (mayPushdown && args[0]->isLazy()) { - BaseAggregate::template pushdown>( - groups, rows, args[0]); - return; - } - BaseAggregate::template updateGroups( - groups, rows, args[0], updateGroup, mayPushdown); - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addRawInput(groups, rows, args, mayPushdown); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - BaseAggregate::updateOneGroup( - group, - rows, - args[0], - updateGroup, - [](T& result, T value, int /* unused */) { result = value; }, - mayPushdown, - kInitialValue_); - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addSingleGroupRawInput(group, rows, args, mayPushdown); - } - - protected: - static inline void updateGroup(T& result, T value) { - if constexpr (std::is_floating_point_v) { - if (util::floating_point::NaNAwareGreaterThan{}(result, value)) { - result = value; - } - } else { - if (result > value) { - result = value; - } - } - } - - void initializeNewGroupsInternal( - char** groups, - folly::Range indices) override { - exec::Aggregate::setAllNulls(groups, indices); - for (auto i : indices) { - *exec::Aggregate::value(groups[i]) = kInitialValue_; - } - } - - private: - static const T kInitialValue_; -}; - -template -const T MinAggregate::kInitialValue_ = MinMaxTrait::max(); - -// In velox, NaN is considered larger than infinity for floating point types. -template <> -const float MinAggregate::kInitialValue_ = - MinMaxTrait::quiet_NaN(); - -template <> -const double MinAggregate::kInitialValue_ = - MinMaxTrait::quiet_NaN(); - -class NonNumericMinMaxAggregateBase : public exec::Aggregate { - public: - explicit NonNumericMinMaxAggregateBase( - const TypePtr& resultType, - bool throwOnNestedNulls) - : exec::Aggregate(resultType), throwOnNestedNulls_(throwOnNestedNulls) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(SingleValueAccumulator); - } - - bool supportsToIntermediate() const override { - return true; - } - - void toIntermediate( - const SelectivityVector& rows, - std::vector& args, - VectorPtr& result) const override { - const auto& input = args[0]; - - if (throwOnNestedNulls_) { - DecodedVector decoded(*input, rows, true); - auto indices = decoded.indices(); - rows.applyToSelected([&](vector_size_t i) { - velox::functions::checkNestedNulls( - decoded, indices, i, throwOnNestedNulls_); - }); - } - - if (rows.isAllSelected()) { - result = input; - return; - } - - auto* pool = allocator_->pool(); - - // Set result to NULL for rows that are masked out. - BufferPtr nulls = allocateNulls(rows.size(), pool, bits::kNull); - rows.clearNulls(nulls); - - BufferPtr indices = allocateIndices(rows.size(), pool); - auto* rawIndices = indices->asMutable(); - std::iota(rawIndices, rawIndices + rows.size(), 0); - - result = BaseVector::wrapInDictionary(nulls, indices, rows.size(), input); - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - VELOX_CHECK(result); - (*result)->resize(numGroups); - - uint64_t* rawNulls = nullptr; - if ((*result)->mayHaveNulls()) { - BufferPtr& nulls = (*result)->mutableNulls((*result)->size()); - rawNulls = nulls->asMutable(); - } - - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - auto accumulator = value(group); - if (!accumulator->hasValue()) { - (*result)->setNull(i, true); - } else { - if (rawNulls) { - bits::clearBit(rawNulls, i); - } - accumulator->read(*result, i); - } - } - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - // partial and final aggregations are the same - extractValues(groups, numGroups, result); - } - - protected: - template - void doUpdate( - char** groups, - const SelectivityVector& rows, - const VectorPtr& arg, - TCompareTest compareTest) { - DecodedVector decoded(*arg, rows, true); - auto indices = decoded.indices(); - auto baseVector = decoded.base(); - - if (decoded.isConstantMapping() && decoded.isNullAt(0)) { - // nothing to do; all values are nulls - return; - } - - rows.applyToSelected([&](vector_size_t i) { - if (velox::functions::checkNestedNulls( - decoded, indices, i, throwOnNestedNulls_)) { - return; - } - - auto accumulator = value(groups[i]); - if (!accumulator->hasValue() || - compareTest(compare(accumulator, decoded, i))) { - accumulator->write(baseVector, indices[i], allocator_); - } - }); - } - - template - void doUpdateSingleGroup( - char* group, - const SelectivityVector& rows, - const VectorPtr& arg, - TCompareTest compareTest) { - DecodedVector decoded(*arg, rows, true); - auto indices = decoded.indices(); - auto baseVector = decoded.base(); - - if (decoded.isConstantMapping()) { - if (velox::functions::checkNestedNulls( - decoded, indices, 0, throwOnNestedNulls_)) { - return; - } - - auto accumulator = value(group); - if (!accumulator->hasValue() || - compareTest(compare(accumulator, decoded, 0))) { - accumulator->write(baseVector, indices[0], allocator_); - } - return; - } - - auto accumulator = value(group); - rows.applyToSelected([&](vector_size_t i) { - if (velox::functions::checkNestedNulls( - decoded, indices, i, throwOnNestedNulls_)) { - return; - } - - if (!accumulator->hasValue() || - compareTest(compare(accumulator, decoded, i))) { - accumulator->write(baseVector, indices[i], allocator_); - } - }); - } - - void initializeNewGroupsInternal( - char** groups, - folly::Range indices) override { - exec::Aggregate::setAllNulls(groups, indices); - for (auto i : indices) { - new (groups[i] + offset_) SingleValueAccumulator(); - } - } - - void destroyInternal(folly::Range groups) override { - for (auto group : groups) { - if (isInitialized(group)) { - value(group)->destroy(allocator_); - } - } - } - - private: - const bool throwOnNestedNulls_; -}; - -class NonNumericMaxAggregate : public NonNumericMinMaxAggregateBase { - public: - explicit NonNumericMaxAggregate( - const TypePtr& resultType, - bool throwOnNestedNulls) - : NonNumericMinMaxAggregateBase(resultType, throwOnNestedNulls) {} - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - doUpdate(groups, rows, args[0], [](int32_t compareResult) { - return compareResult < 0; - }); - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addRawInput(groups, rows, args, mayPushdown); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { - return compareResult < 0; - }); - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addSingleGroupRawInput(group, rows, args, mayPushdown); - } -}; - -class NonNumericMinAggregate : public NonNumericMinMaxAggregateBase { - public: - explicit NonNumericMinAggregate( - const TypePtr& resultType, - bool throwOnNestedNulls) - : NonNumericMinMaxAggregateBase(resultType, throwOnNestedNulls) {} - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - doUpdate(groups, rows, args[0], [](int32_t compareResult) { - return compareResult > 0; - }); - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addRawInput(groups, rows, args, mayPushdown); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { - return compareResult > 0; - }); - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool mayPushdown) override { - addSingleGroupRawInput(group, rows, args, mayPushdown); - } -}; - std::pair rawOffsetAndSizes( ArrayVector& arrayVector) { return { @@ -932,16 +418,12 @@ class MaxNAggregate : public MinMaxNAggregateBase> { : MinMaxNAggregateBase>(resultType) {} }; -template < - template - class TNumeric, - typename TNonNumeric, - template - class TNumericN> +template