From 33994cd94158444a797081a1515b7b610caba83b Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Mon, 11 Mar 2024 12:06:16 -0700 Subject: [PATCH] Add skewness Spark agg function (#7513) Summary: There are some inconsistencies between the skewness calculations in Spark and Presto. In Presto, the skewness calculation requires `count >= 3` to produce a result, whereas in Spark, `count >= 1` is required. Additionally, Spark also has a requirement for `m2 != 0`. Therefore, it is necessary to move `CentralMomentsAggregates` to the `functions/lib` directory for reuse by both Spark and Presto. Spark and Presto can then implement their own respective `SkewnessResultAccessor`. Spark skewness: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L291-L309 In addition, the algorithm for calculating kurtosis in Spark is different from Presto, so currently they cannot be reused. However, there are plans to continue working on adapting it in the future. 
Pull Request resolved: https://github.com/facebookincubator/velox/pull/7513 Reviewed By: pedroerp Differential Revision: D54699558 Pulled By: Yuhta fbshipit-source-id: 1e9cbaecabd59d98b706d9a7de1c7bb747cbd9d4 --- velox/docs/functions/spark/aggregate.rst | 6 + velox/functions/lib/aggregates/CMakeLists.txt | 6 +- .../CentralMomentsAggregatesBase.cpp | 52 ++ .../aggregates/CentralMomentsAggregatesBase.h | 448 +++++++++++++++++ .../aggregates/CentralMomentsAggregates.cpp | 470 +----------------- .../sparksql/aggregates/CMakeLists.txt | 1 + .../aggregates/CentralMomentsAggregate.cpp | 114 +++++ .../aggregates/CentralMomentsAggregate.h | 30 ++ .../sparksql/aggregates/Register.cpp | 2 + .../sparksql/aggregates/tests/CMakeLists.txt | 3 +- .../tests/CentralMomentsAggregationTest.cpp | 60 +++ .../fuzzer/SparkAggregationFuzzerTest.cpp | 7 +- 12 files changed, 739 insertions(+), 460 deletions(-) create mode 100644 velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp create mode 100644 velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h create mode 100644 velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp create mode 100644 velox/functions/sparksql/aggregates/CentralMomentsAggregate.h create mode 100644 velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index a43c95042aca1..c09b42fb78ca2 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -103,6 +103,12 @@ General Aggregate Functions Returns b +.. spark:function:: skewness(x) -> double + + Returns the skewness of all non-null input values. Returns null when there are no non-null + input values, or when the accumulated second central moment `m2` is 0, for example when + all input values are equal. + .. spark:function:: sum(x) -> bigint|double|real Returns the sum of `x`. 
diff --git a/velox/functions/lib/aggregates/CMakeLists.txt b/velox/functions/lib/aggregates/CMakeLists.txt index 19e3eee7de655..1eabebce1f16d 100644 --- a/velox/functions/lib/aggregates/CMakeLists.txt +++ b/velox/functions/lib/aggregates/CMakeLists.txt @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_functions_aggregates SingleValueAccumulator.cpp - AverageAggregateBase.cpp ValueSet.cpp) +add_library( + velox_functions_aggregates + AverageAggregateBase.cpp CentralMomentsAggregatesBase.cpp + SingleValueAccumulator.cpp ValueSet.cpp) target_link_libraries(velox_functions_aggregates velox_exec velox_presto_serializer Folly::folly) diff --git a/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp new file mode 100644 index 0000000000000..4e972fb24795a --- /dev/null +++ b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" + +namespace facebook::velox::functions::aggregate { + +void checkAccumulatorRowType( + const TypePtr& type, + const std::string& errorMessage) { + VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.count)->kind(), + TypeKind::BIGINT, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m1)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m2)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m3)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); + VELOX_CHECK_EQ( + type->childAt(kCentralMomentsIndices.m4)->kind(), + TypeKind::DOUBLE, + "{}", + errorMessage); +} + +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h new file mode 100644 index 0000000000000..7b0191db235ef --- /dev/null +++ b/velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h @@ -0,0 +1,448 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "velox/exec/Aggregate.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::functions::aggregate { + +// Indices into RowType representing intermediate results of skewness and +// kurtosis. Columns appear in alphabetical order. +struct CentralMomentsIndices { + int32_t count; + int32_t m1; + int32_t m2; + int32_t m3; + int32_t m4; +}; +constexpr CentralMomentsIndices kCentralMomentsIndices{0, 1, 2, 3, 4}; + +struct CentralMomentsAccumulator { + int64_t count() const { + return count_; + } + + double m1() const { + return m1_; + } + + double m2() const { + return m2_; + } + + double m3() const { + return m3_; + } + + double m4() const { + return m4_; + } + + void update(double value) { + double oldCount = count(); + count_ += 1; + double oldM1 = m1(); + double oldM2 = m2(); + double oldM3 = m3(); + double delta = value - oldM1; + double deltaN = delta / count(); + double deltaN2 = deltaN * deltaN; + double dm2 = delta * deltaN * oldCount; + + m1_ += deltaN; + m2_ += dm2; + m3_ += dm2 * deltaN * (count() - 2) - 3 * deltaN * oldM2; + m4_ += dm2 * deltaN2 * (1.0 * count() * count() - 3.0 * count() + 3) + + 6 * deltaN2 * oldM2 - 4 * deltaN * oldM3; + } + + inline void merge(const CentralMomentsAccumulator& other) { + merge(other.count(), other.m1(), other.m2(), other.m3(), other.m4()); + } + + void merge( + double otherCount, + double otherM1, + double otherM2, + double otherM3, + double otherM4) { + if (otherCount == 0) { + return; + } + if (count_ == 0) { + count_ = otherCount; + m1_ = otherM1; + m2_ = otherM2; + m3_ = otherM3; + m4_ = otherM4; + return; + } + + double oldCount = count(); + count_ += otherCount; + + double oldM1 = m1(); + double oldM2 = m2(); + double oldM3 = m3(); + double delta = otherM1 - oldM1; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m1_ = (oldCount * oldM1 + otherCount * otherM1) / count(); + m2_ += otherM2 + delta2 * oldCount * 
otherCount / count(); + m3_ += otherM3 + + delta3 * oldCount * otherCount * (oldCount - otherCount) / + (1.0 * count() * count()) + + 3 * delta * (oldCount * otherM2 - otherCount * oldM2) / count(); + m4_ += otherM4 + + delta4 * oldCount * otherCount * + (oldCount * oldCount - oldCount * otherCount + + otherCount * otherCount) / + (1.0 * count() * count() * count()) + + 6 * delta2 * + (oldCount * oldCount * otherM2 + otherCount * otherCount * oldM2) / + (1.0 * count() * count()) + + 4 * delta * (oldCount * otherM3 - otherCount * oldM3) / count(); + } + + private: + int64_t count_{0}; + double m1_{0}; + double m2_{0}; + double m3_{0}; + double m4_{0}; +}; + +template +SimpleVector* asSimpleVector( + const RowVector* rowVector, + int32_t childIndex) { + auto result = rowVector->childAt(childIndex)->as>(); + VELOX_CHECK_NOT_NULL(result); + return result; +} + +class CentralMomentsIntermediateInput { + public: + explicit CentralMomentsIntermediateInput( + const RowVector* rowVector, + const CentralMomentsIndices& indices = kCentralMomentsIndices) + : count_{asSimpleVector(rowVector, indices.count)}, + m1_{asSimpleVector(rowVector, indices.m1)}, + m2_{asSimpleVector(rowVector, indices.m2)}, + m3_{asSimpleVector(rowVector, indices.m3)}, + m4_{asSimpleVector(rowVector, indices.m4)} {} + + void mergeInto(CentralMomentsAccumulator& accumulator, vector_size_t row) { + accumulator.merge( + count_->valueAt(row), + m1_->valueAt(row), + m2_->valueAt(row), + m3_->valueAt(row), + m4_->valueAt(row)); + } + + protected: + SimpleVector* count_; + SimpleVector* m1_; + SimpleVector* m2_; + SimpleVector* m3_; + SimpleVector* m4_; +}; + +template +T* mutableRawValues(const RowVector* rowVector, int32_t childIndex) { + return rowVector->childAt(childIndex) + ->as>() + ->mutableRawValues(); +} + +class CentralMomentsIntermediateResult { + public: + explicit CentralMomentsIntermediateResult( + const RowVector* rowVector, + const CentralMomentsIndices& indices = kCentralMomentsIndices) + : 
count_{mutableRawValues(rowVector, indices.count)}, + m1_{mutableRawValues(rowVector, indices.m1)}, + m2_{mutableRawValues(rowVector, indices.m2)}, + m3_{mutableRawValues(rowVector, indices.m3)}, + m4_{mutableRawValues(rowVector, indices.m4)} {} + + static std::string type() { + return "row(bigint,double,double,double,double)"; + } + + void set(vector_size_t row, const CentralMomentsAccumulator& accumulator) { + count_[row] = accumulator.count(); + m1_[row] = accumulator.m1(); + m2_[row] = accumulator.m2(); + m3_[row] = accumulator.m3(); + m4_[row] = accumulator.m4(); + } + + private: + int64_t* count_; + double* m1_; + double* m2_; + double* m3_; + double* m4_; +}; + +// T is the input type for partial aggregation, it can be integer, double or +// float. Not used for final aggregation. TResultAccessor is the type of the +// static struct that will access the result in a certain way from the +// CentralMoments Accumulator. +template +class CentralMomentsAggregatesBase : public exec::Aggregate { + public: + explicit CentralMomentsAggregatesBase(TypePtr resultType) + : exec::Aggregate(resultType) {} + + int32_t accumulatorAlignmentSize() const override { + return alignof(CentralMomentsAccumulator); + } + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(CentralMomentsAccumulator); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) CentralMomentsAccumulator(); + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedRaw_.decode(*args[0], rows); + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected( + [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t 
i) { + if (decodedRaw_.isNullAt(i)) { + return; + } + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], data[i]); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedRaw_.decode(*args[0], rows); + + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected( + [&](vector_size_t i) { updateNonNullValue(group, value); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedRaw_.isNullAt(i)) { + updateNonNullValue(group, decodedRaw_.valueAt(i)); + } + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + CentralMomentsAccumulator accData; + rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); + updateNonNullValue(group, accData); + } else { + CentralMomentsAccumulator accData; + rows.applyToSelected( + [&](vector_size_t i) { accData.update(decodedRaw_.valueAt(i)); }); + updateNonNullValue(group, accData); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + CentralMomentsIntermediateInput input{baseRowVector}; + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + rows.applyToSelected([&](vector_size_t i) { + 
exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(groups[i]); + input.mergeInto(*accumulator(groups[i]), decodedIndex); + }); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + CentralMomentsIntermediateInput input{baseRowVector}; + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + CentralMomentsAccumulator accData; + rows.applyToSelected( + [&](vector_size_t i) { input.mergeInto(accData, decodedIndex); }); + updateNonNullValue(group, accData); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + exec::Aggregate::clearNull(group); + input.mergeInto(*accumulator(group), decodedIndex); + }); + } else { + CentralMomentsAccumulator accData; + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + input.mergeInto(accData, decodedIndex); + }); + updateNonNullValue(group, accData); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + uint64_t* rawNulls = getRawNulls(vector); + + double* 
rawValues = vector->mutableRawValues(); + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + vector->setNull(i, true); + } else { + auto* accData = accumulator(group); + if (TResultAccessor::hasResult(*accData)) { + clearNull(rawNulls, i); + rawValues[i] = TResultAccessor::result(*accData); + } else { + vector->setNull(i, true); + } + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto rowVector = (*result)->as(); + rowVector->resize(numGroups); + for (auto& child : rowVector->children()) { + child->resize(numGroups); + } + + uint64_t* rawNulls = getRawNulls(rowVector); + + CentralMomentsIntermediateResult centralMomentsResult{rowVector}; + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + rowVector->setNull(i, true); + } else { + clearNull(rawNulls, i); + centralMomentsResult.set(i, *accumulator(group)); + } + } + } + + private: + template + inline void updateNonNullValue(char* group, T value) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + CentralMomentsAccumulator* accData = accumulator(group); + accData->update((double)value); + } + + template + inline void updateNonNullValue( + char* group, + const CentralMomentsAccumulator& accData) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + CentralMomentsAccumulator* thisAccData = accumulator(group); + thisAccData->merge(accData); + } + + inline CentralMomentsAccumulator* accumulator(char* group) { + return exec::Aggregate::value(group); + } + + DecodedVector decodedRaw_; + DecodedVector decodedPartial_; +}; + +void checkAccumulatorRowType( + const TypePtr& type, + const std::string& errorMessage); + +} // namespace facebook::velox::functions::aggregate diff --git a/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp b/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp index 
101ce8e124284..1559e8a0f38a5 100644 --- a/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp +++ b/velox/functions/prestosql/aggregates/CentralMomentsAggregates.cpp @@ -13,112 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/Aggregate.h" +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/vector/FlatVector.h" - -namespace facebook::velox::aggregate::prestosql { - -namespace { -// Indices into RowType representing intermediate results of skewness and -// kurtosis. Columns appear in alphabetical order. -struct CentralMomentsIndices { - int32_t count; - int32_t m1; - int32_t m2; - int32_t m3; - int32_t m4; -}; -constexpr CentralMomentsIndices kCentralMomentsIndices{0, 1, 2, 3, 4}; - -struct CentralMomentsAccumulator { - double count() const { - return count_; - } - - double m1() const { - return m1_; - } - - double m2() const { - return m2_; - } - - double m3() const { - return m3_; - } - - double m4() const { - return m4_; - } - - void update(double value) { - double oldCount = count(); - count_ += 1; - double oldM1 = m1(); - double oldM2 = m2(); - double oldM3 = m3(); - double delta = value - oldM1; - double deltaN = delta / count(); - double deltaN2 = deltaN * deltaN; - double dm2 = delta * deltaN * oldCount; - - m1_ += deltaN; - m2_ += dm2; - m3_ += dm2 * deltaN * (count() - 2) - 3 * deltaN * oldM2; - m4_ += dm2 * deltaN2 * (count() * (double)count() - 3 * count() + 3) + - 6 * deltaN2 * oldM2 - 4 * deltaN * oldM3; - } - - inline void merge(const CentralMomentsAccumulator& other) { - merge(other.count(), other.m1(), other.m2(), other.m3(), other.m4()); - } - - void merge( - double otherCount, - double otherM1, - double otherM2, - double otherM3, - double otherM4) { - if (otherCount == 0) { - return; - } - - double oldCount = count(); - count_ += otherCount; - 
double oldM1 = m1(); - double oldM2 = m2(); - double oldM3 = m3(); - double delta = otherM1 - oldM1; - double delta2 = delta * delta; - double delta3 = delta * delta2; - double delta4 = delta2 * delta2; +using namespace facebook::velox::functions::aggregate; - m1_ = (oldCount * oldM1 + otherCount * otherM1) / count(); - m2_ += otherM2 + delta2 * oldCount * otherCount / count(); - m3_ += otherM3 + - delta3 * oldCount * otherCount * (oldCount - otherCount) / - (count() * count()) + - 3 * delta * (oldCount * otherM2 - otherCount * oldM2) / count(); - m4_ += otherM4 + - delta4 * oldCount * otherCount * - (oldCount * oldCount - oldCount * otherCount + - otherCount * otherCount) / - (count() * count() * count()) + - 6 * delta2 * - (oldCount * oldCount * otherM2 + otherCount * otherCount * oldM2) / - (count() * count()) + - 4 * delta * (oldCount * otherM3 - otherCount * oldM3) / count(); - } - - private: - int64_t count_{0}; - double m1_{0}; - double m2_{0}; - double m3_{0}; - double m4_{0}; -}; +namespace facebook::velox::aggregate::prestosql { struct SkewnessResultAccessor { static bool hasResult(const CentralMomentsAccumulator& accumulator) { @@ -146,350 +48,6 @@ struct KurtosisResultAccessor { } }; -template -SimpleVector* asSimpleVector( - const RowVector* rowVector, - int32_t childIndex) { - auto result = rowVector->childAt(childIndex)->as>(); - VELOX_CHECK_NOT_NULL(result); - return result; -} - -class CentralMomentsIntermediateInput { - public: - explicit CentralMomentsIntermediateInput( - const RowVector* rowVector, - const CentralMomentsIndices& indices = kCentralMomentsIndices) - : count_{asSimpleVector(rowVector, indices.count)}, - m1_{asSimpleVector(rowVector, indices.m1)}, - m2_{asSimpleVector(rowVector, indices.m2)}, - m3_{asSimpleVector(rowVector, indices.m3)}, - m4_{asSimpleVector(rowVector, indices.m4)} {} - - void mergeInto(CentralMomentsAccumulator& accumulator, vector_size_t row) { - accumulator.merge( - count_->valueAt(row), - m1_->valueAt(row), - 
m2_->valueAt(row), - m3_->valueAt(row), - m4_->valueAt(row)); - } - - protected: - SimpleVector* count_; - SimpleVector* m1_; - SimpleVector* m2_; - SimpleVector* m3_; - SimpleVector* m4_; -}; - -template -T* mutableRawValues(const RowVector* rowVector, int32_t childIndex) { - return rowVector->childAt(childIndex) - ->as>() - ->mutableRawValues(); -} - -class CentralMomentsIntermediateResult { - public: - explicit CentralMomentsIntermediateResult( - const RowVector* rowVector, - const CentralMomentsIndices& indices = kCentralMomentsIndices) - : count_{mutableRawValues(rowVector, indices.count)}, - m1_{mutableRawValues(rowVector, indices.m1)}, - m2_{mutableRawValues(rowVector, indices.m2)}, - m3_{mutableRawValues(rowVector, indices.m3)}, - m4_{mutableRawValues(rowVector, indices.m4)} {} - - static std::string type() { - return "row(bigint,double,double,double,double)"; - } - - void set(vector_size_t row, const CentralMomentsAccumulator& accumulator) { - count_[row] = accumulator.count(); - m1_[row] = accumulator.m1(); - m2_[row] = accumulator.m2(); - m3_[row] = accumulator.m3(); - m4_[row] = accumulator.m4(); - } - - private: - int64_t* count_; - double* m1_; - double* m2_; - double* m3_; - double* m4_; -}; - -// T is the input type for partial aggregation, it can be integer, double or -// float. Not used for final aggregation. TResultAccessor is the type of the -// static struct that will access the result in a certain way from the -// CentralMoments Accumulator. 
-template -class CentralMomentsAggregate : public exec::Aggregate { - public: - explicit CentralMomentsAggregate(TypePtr resultType) - : exec::Aggregate(resultType) {} - - int32_t accumulatorAlignmentSize() const override { - return alignof(CentralMomentsAccumulator); - } - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(CentralMomentsAccumulator); - } - - void initializeNewGroups( - char** groups, - folly::Range indices) override { - setAllNulls(groups, indices); - for (auto i : indices) { - new (groups[i] + offset_) CentralMomentsAccumulator(); - } - } - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedRaw_.decode(*args[0], rows); - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - auto value = decodedRaw_.valueAt(0); - rows.applyToSelected( - [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedRaw_.isNullAt(i)) { - return; - } - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - auto data = decodedRaw_.data(); - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], data[i]); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedRaw_.decode(*args[0], rows); - - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - auto value = decodedRaw_.valueAt(0); - rows.applyToSelected( - [&](vector_size_t i) { updateNonNullValue(group, value); }); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if 
(!decodedRaw_.isNullAt(i)) { - updateNonNullValue(group, decodedRaw_.valueAt(i)); - } - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - auto data = decodedRaw_.data(); - CentralMomentsAccumulator accData; - rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); - updateNonNullValue(group, accData); - } else { - CentralMomentsAccumulator accData; - rows.applyToSelected( - [&](vector_size_t i) { accData.update(decodedRaw_.valueAt(i)); }); - updateNonNullValue(group, accData); - } - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - CentralMomentsIntermediateInput input{baseRowVector}; - - if (decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - rows.applyToSelected([&](vector_size_t i) { - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedPartial_.isNullAt(i)) { - return; - } - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(groups[i]); - input.mergeInto(*accumulator(groups[i]), decodedIndex); - }); - } - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - CentralMomentsIntermediateInput input{baseRowVector}; - - if 
(decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - CentralMomentsAccumulator accData; - rows.applyToSelected( - [&](vector_size_t i) { input.mergeInto(accData, decodedIndex); }); - updateNonNullValue(group, accData); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedPartial_.isNullAt(i)) { - return; - } - auto decodedIndex = decodedPartial_.index(i); - exec::Aggregate::clearNull(group); - input.mergeInto(*accumulator(group), decodedIndex); - }); - } else { - CentralMomentsAccumulator accData; - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - input.mergeInto(accData, decodedIndex); - }); - updateNonNullValue(group, accData); - } - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto vector = (*result)->as>(); - VELOX_CHECK(vector); - vector->resize(numGroups); - uint64_t* rawNulls = getRawNulls(vector); - - double* rawValues = vector->mutableRawValues(); - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - vector->setNull(i, true); - } else { - auto* accData = accumulator(group); - if (TResultAccessor::hasResult(*accData)) { - clearNull(rawNulls, i); - rawValues[i] = TResultAccessor::result(*accData); - } else { - vector->setNull(i, true); - } - } - } - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto rowVector = (*result)->as(); - rowVector->resize(numGroups); - for (auto& child : rowVector->children()) { - child->resize(numGroups); - } - - uint64_t* rawNulls = getRawNulls(rowVector); - - CentralMomentsIntermediateResult centralMomentsResult{rowVector}; - - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - rowVector->setNull(i, true); - } else { - clearNull(rawNulls, i); - centralMomentsResult.set(i, 
*accumulator(group)); - } - } - } - - private: - template - inline void updateNonNullValue(char* group, T value) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - CentralMomentsAccumulator* accData = accumulator(group); - accData->update((double)value); - } - - template - inline void updateNonNullValue( - char* group, - const CentralMomentsAccumulator& accData) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - CentralMomentsAccumulator* thisAccData = accumulator(group); - thisAccData->merge(accData); - } - - inline CentralMomentsAccumulator* accumulator(char* group) { - return exec::Aggregate::value(group); - } - - DecodedVector decodedRaw_; - DecodedVector decodedPartial_; -}; - -void checkAccumulatorRowType( - const TypePtr& type, - const std::string& errorMessage) { - VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.count)->kind(), - TypeKind::BIGINT, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m1)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m2)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m3)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); - VELOX_CHECK_EQ( - type->childAt(kCentralMomentsIndices.m4)->kind(), - TypeKind::DOUBLE, - "{}", - errorMessage); -} - template exec::AggregateRegistrationResult registerCentralMoments( const std::string& name, @@ -523,22 +81,24 @@ exec::AggregateRegistrationResult registerCentralMoments( switch (inputType->kind()) { case TypeKind::SMALLINT: return std::make_unique< - CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::INTEGER: return std::make_unique< - CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::BIGINT: return std::make_unique< - 
CentralMomentsAggregate>( + CentralMomentsAggregatesBase>( resultType); case TypeKind::DOUBLE: return std::make_unique< - CentralMomentsAggregate>(resultType); + CentralMomentsAggregatesBase>( + resultType); case TypeKind::REAL: return std::make_unique< - CentralMomentsAggregate>(resultType); + CentralMomentsAggregatesBase>( + resultType); default: VELOX_UNSUPPORTED( "Unsupported input type: {}. " @@ -550,17 +110,15 @@ exec::AggregateRegistrationResult registerCentralMoments( inputType, "Input type for final aggregation must be " "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); - // final agg not use template T, int64_t here has no effect. - return std::make_unique< - CentralMomentsAggregate>(resultType); + return std::make_unique>(resultType); } }, withCompanionFunctions, overwrite); } -} // namespace - void registerCentralMomentsAggregates( const std::string& prefix, bool withCompanionFunctions, diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt index 77101bb8717ce..011ff1dfeb398 100644 --- a/velox/functions/sparksql/aggregates/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/CMakeLists.txt @@ -16,6 +16,7 @@ add_library( AverageAggregate.cpp BitwiseXorAggregate.cpp BloomFilterAggAggregate.cpp + CentralMomentsAggregate.cpp FirstLastAggregate.cpp MinMaxByAggregate.cpp Register.cpp diff --git a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp new file mode 100644 index 0000000000000..1a8d3a34a2c15 --- /dev/null +++ b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" +#include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +namespace { +struct SkewnessResultAccessor { + static bool hasResult(const CentralMomentsAccumulator& accumulator) { + return accumulator.count() >= 1 && accumulator.m2() != 0; + } + + static double result(const CentralMomentsAccumulator& accumulator) { + return std::sqrt(accumulator.count()) * accumulator.m3() / + std::pow(accumulator.m2(), 1.5); + } +}; + +template +exec::AggregateRegistrationResult registerCentralMoments( + const std::string& name, + bool withCompanionFunctions, + bool overwrite) { + std::vector> signatures; + std::vector inputTypes = { + "smallint", "integer", "bigint", "real", "double"}; + for (const auto& inputType : inputTypes) { + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType(CentralMomentsIntermediateResult::type()) + .argumentType(inputType) + .build()); + } + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + const auto& inputType = argTypes[0]; + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::SMALLINT: + return std::make_unique< + 
CentralMomentsAggregatesBase>( + resultType); + case TypeKind::INTEGER: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::BIGINT: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::DOUBLE: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + case TypeKind::REAL: + return std::make_unique< + CentralMomentsAggregatesBase>( + resultType); + default: + VELOX_UNSUPPORTED( + "Unsupported input type: {}. " + "Expected SMALLINT, INTEGER, BIGINT, DOUBLE or REAL.", + inputType->toString()) + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique>(resultType); + } + }, + withCompanionFunctions, + overwrite); +} +} // namespace + +void registerCentralMomentsAggregate( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite) { + registerCentralMoments( + prefix + "skewness", withCompanionFunctions, overwrite); +} + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h new file mode 100644 index 0000000000000..a0b9a69297e3f --- /dev/null +++ b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/exec/Aggregate.h" + +namespace facebook::velox::functions::aggregate::sparksql { + +void registerCentralMomentsAggregate( + const std::string& name, + bool withCompanionFunctions, + bool overwrite); + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index 79e9c076aa1a0..7a2886dff2dee 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -19,6 +19,7 @@ #include "velox/functions/sparksql/aggregates/AverageAggregate.h" #include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h" #include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" +#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" #include "velox/functions/sparksql/aggregates/SumAggregate.h" namespace facebook::velox::functions::aggregate::sparksql { @@ -43,5 +44,6 @@ void registerAggregateFunctions( prefix + "bloom_filter_agg", withCompanionFunctions, overwrite); registerAverage(prefix + "avg", withCompanionFunctions, overwrite); registerSum(prefix + "sum", withCompanionFunctions, overwrite); + registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt index 22730f9d7e578..f7a5fdf05cdc1 100644 --- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt @@ -14,11 +14,12 @@ add_executable( velox_functions_spark_aggregates_test + AverageAggregationTest.cpp BitwiseXorAggregationTest.cpp BloomFilterAggAggregateTest.cpp + CentralMomentsAggregationTest.cpp 
FirstAggregateTest.cpp LastAggregateTest.cpp - AverageAggregationTest.cpp Main.cpp MinMaxByAggregationTest.cpp SumAggregationTest.cpp) diff --git a/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp new file mode 100644 index 0000000000000..1a557f4b5b06b --- /dev/null +++ b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/sparksql/aggregates/Register.h" + +using namespace facebook::velox::exec::test; +using namespace facebook::velox::functions::aggregate::test; + +namespace facebook::velox::functions::aggregate::sparksql::test { + +namespace { +class CentralMomentsAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerAggregateFunctions("spark_"); + } + + void testSkewnessResult( + const RowVectorPtr& input, + const RowVectorPtr& expected) { + PlanBuilder builder(pool()); + builder.values({input}); + builder.singleAggregation({}, {"spark_skewness(c0)"}); + AssertQueryBuilder queryBuilder( + builder.planNode(), this->duckDbQueryRunner_); + queryBuilder.assertResults({expected}); + } +}; + +TEST_F(CentralMomentsAggregationTest, skewnessHasResult) { + auto input = makeRowVector({makeFlatVector({1, 2})}); + // Even when the count is 2, Spark still produces output. 
+ auto expected = + makeRowVector({makeFlatVector(std::vector{0.0})}); + testSkewnessResult(input, expected); + + input = makeRowVector({makeFlatVector({1, 1})}); + expected = makeRowVector({makeNullableFlatVector( + std::vector>{std::nullopt})}); + testSkewnessResult(input, expected); +} + +} // namespace +} // namespace facebook::velox::functions::aggregate::sparksql::test diff --git a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp index c23e4c9368c99..069aed0314056 100644 --- a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp +++ b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp @@ -69,7 +69,8 @@ int main(int argc, char** argv) { {"first", nullptr}, {"first_ignore_null", nullptr}, {"max_by", nullptr}, - {"min_by", nullptr}}; + {"min_by", nullptr}, + {"skewness", nullptr}}; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; auto duckQueryRunner = @@ -78,6 +79,10 @@ int main(int argc, char** argv) { // https://github.com/facebookincubator/velox/issues/7677 "max_by", "min_by", + // The skewness functions of Velox and DuckDB use different + // algorithms. + // https://github.com/facebookincubator/velox/issues/4845 + "skewness", }); using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;