diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp
index 9adbb398bac1..87f91d8c7182 100644
--- a/velox/core/PlanNode.cpp
+++ b/velox/core/PlanNode.cpp
@@ -237,8 +237,16 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const {
   }
   // TODO: add spilling for pre-grouped aggregation later:
   // https://github.com/facebookincubator/velox/issues/3264
-  return (isFinal() || isSingle()) && preGroupedKeys().empty() &&
-      queryConfig.aggregationSpillEnabled();
+  if ((isFinal() || isSingle()) && queryConfig.aggregationSpillEnabled()) {
+    return preGroupedKeys().empty();
+  }
+
+  if ((isIntermediate() || isPartial()) &&
+      queryConfig.partialAggregationSpillEnabled()) {
+    return preGroupedKeys().empty();
+  }
+
+  return false;
 }
 
 void AggregationNode::addDetails(std::stringstream& stream) const {
diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h
index a041cdc9aee6..99073e4380e9 100644
--- a/velox/core/PlanNode.h
+++ b/velox/core/PlanNode.h
@@ -607,6 +607,14 @@ class AggregationNode : public PlanNode {
     return step_ == Step::kSingle;
   }
 
+  bool isIntermediate() const {
+    return step_ == Step::kIntermediate;
+  }
+
+  bool isPartial() const {
+    return step_ == Step::kPartial;
+  }
+
   folly::dynamic serialize() const override;
 
   static PlanNodePtr create(const folly::dynamic& obj, void* context);
diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h
index 340fdd5412a3..2fc6a903d06e 100644
--- a/velox/core/QueryConfig.h
+++ b/velox/core/QueryConfig.h
@@ -198,6 +198,11 @@ class QueryConfig {
   static constexpr const char* kAggregationSpillEnabled =
       "aggregation_spill_enabled";
 
+  /// Partial aggregation spilling flag, only applies if "spill_enabled" flag is
+  /// set.
+  static constexpr const char* kPartialAggregationSpillEnabled =
+      "partial_aggregation_spill_enabled";
+
   /// Join spilling flag, only applies if "spill_enabled" flag is set.
   static constexpr const char* kJoinSpillEnabled = "join_spill_enabled";
 
@@ -493,11 +498,17 @@ class QueryConfig {
   }
 
   /// Returns 'is aggregation spilling enabled' flag. Must also check the
-  /// spillEnabled()!g
+  /// spillEnabled()!
   bool aggregationSpillEnabled() const {
     return get<bool>(kAggregationSpillEnabled, true);
   }
 
+  /// Returns 'is partial aggregation spilling enabled' flag. Must also check
+  /// the spillEnabled()!
+  bool partialAggregationSpillEnabled() const {
+    return get<bool>(kPartialAggregationSpillEnabled, false);
+  }
+
   /// Returns 'is join spilling enabled' flag. Must also check the
   /// spillEnabled()!
   bool joinSpillEnabled() const {
diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp
index 9ecf91f79a50..636e8a9170a3 100644
--- a/velox/exec/GroupingSet.cpp
+++ b/velox/exec/GroupingSet.cpp
@@ -725,6 +725,7 @@ bool GroupingSet::getOutput(
   }
 
   if (hasSpilled()) {
+    spill();
     return getOutputWithSpill(maxOutputRows, maxOutputBytes, result);
   }
   VELOX_CHECK(!isDistinct());
@@ -826,7 +827,7 @@ const HashLookup& GroupingSet::hashLookup() const {
 void GroupingSet::ensureInputFits(const RowVectorPtr& input) {
   // Spilling is considered if this is a final or single aggregation and
   // spillPath is set.
-  if (isPartial_ || spillConfig_ == nullptr) {
+  if (spillConfig_ == nullptr) {
     return;
   }
 
@@ -909,7 +910,7 @@ void GroupingSet::ensureOutputFits() {
   // to reserve memory for the output as we can't reclaim much memory from this
   // operator itself. The output processing can reclaim memory from the other
   // operator or query through memory arbitration.
-  if (isPartial_ || spillConfig_ == nullptr || hasSpilled()) {
+  if (spillConfig_ == nullptr || hasSpilled()) {
     return;
   }
 
@@ -929,6 +930,9 @@ void GroupingSet::ensureOutputFits() {
       return;
     }
   }
+  if (hasSpilled()) {
+    return;
+  }
   spill(RowContainerIterator{});
 }
 
@@ -955,7 +959,6 @@ void GroupingSet::spill() {
   if (table_ == nullptr || table_->numDistinct() == 0) {
     return;
   }
 
-  if (!hasSpilled()) {
     auto rows = table_->rows();
     VELOX_DCHECK(pool_.trackUsage());
@@ -1045,7 +1048,16 @@ bool GroupingSet::getOutputWithSpill(
   if (merge_ == nullptr) {
     return false;
   }
-  return mergeNext(maxOutputRows, maxOutputBytes, result);
+  bool hasData = mergeNext(maxOutputRows, maxOutputBytes, result);
+  if (!hasData) {
+    // If spill has been finalized, reset merge stream and spiller. This would
+    // help partial aggregation replay the spilling procedure when needed again.
+    merge_ = nullptr;
+    mergeRows_ = nullptr;
+    mergeArgs_.clear();
+    spiller_ = nullptr;
+  }
+  return hasData;
 }
 
 bool GroupingSet::mergeNext(
diff --git a/velox/exec/tests/SharedArbitratorTest.cpp b/velox/exec/tests/SharedArbitratorTest.cpp
index 6577e7d0e70b..3d1e9350c395 100644
--- a/velox/exec/tests/SharedArbitratorTest.cpp
+++ b/velox/exec/tests/SharedArbitratorTest.cpp
@@ -1053,6 +1053,50 @@ TEST_F(SharedArbitrationTest, reclaimFromDistinctAggregation) {
   waitForAllTasksToBeDeleted();
 }
 
+TEST_F(SharedArbitrationTest, reclaimFromPartialAggregation) {
+  const uint64_t maxQueryCapacity = 20L << 20;
+  std::vector<RowVectorPtr> vectors = newVectors(1024, maxQueryCapacity * 2);
+  createDuckDbTable(vectors);
+  const auto spillDirectory = exec::test::TempDirectoryPath::create();
+  core::PlanNodeId partialAggNodeId;
+  core::PlanNodeId finalAggNodeId;
+  std::shared_ptr<core::QueryCtx> queryCtx = newQueryCtx(maxQueryCapacity);
+  auto task =
+      AssertQueryBuilder(duckDbQueryRunner_)
+          .spillDirectory(spillDirectory->path)
+          .config(core::QueryConfig::kSpillEnabled, "true")
+          .config(core::QueryConfig::kPartialAggregationSpillEnabled, "true")
+          .config(core::QueryConfig::kAggregationSpillEnabled, "true")
+          .config(
+              core::QueryConfig::kMaxPartialAggregationMemory,
+              std::to_string(1LL << 30)) // disable flush
+          .config(
+              core::QueryConfig::kMaxExtendedPartialAggregationMemory,
+              std::to_string(1LL << 30)) // disable flush
+          .config(
+              core::QueryConfig::kAbandonPartialAggregationMinPct,
+              "200") // avoid abandoning
+          .config(
+              core::QueryConfig::kAbandonPartialAggregationMinRows,
+              std::to_string(1LL << 30)) // avoid abandoning
+          .queryCtx(queryCtx)
+          .plan(PlanBuilder()
+                    .values(vectors)
+                    .partialAggregation({"c0"}, {"count(1)"})
+                    .capturePlanNodeId(partialAggNodeId)
+                    .finalAggregation()
+                    .capturePlanNodeId(finalAggNodeId)
+                    .planNode())
+          .assertResults("SELECT c0, count(1) FROM tmp GROUP BY c0");
+  auto taskStats = exec::toPlanStats(task->taskStats());
+  auto& partialStats = taskStats.at(partialAggNodeId);
+  auto& finalStats = taskStats.at(finalAggNodeId);
+  ASSERT_GT(partialStats.spilledBytes, 0);
+  ASSERT_GT(finalStats.spilledBytes, 0);
+  task.reset();
+  waitForAllTasksToBeDeleted();
+}
+
 DEBUG_ONLY_TEST_F(SharedArbitrationTest, reclaimFromAggregationOnNoMoreInput) {
   const int numVectors = 32;
   std::vector<RowVectorPtr> vectors;
diff --git a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
index dc24b2095381..528786dbc9a5 100644
--- a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
+++ b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
@@ -639,6 +639,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
   DecodedVector decodedDigest_;
 
  private:
+  bool isConstantVector(const VectorPtr& vec) {
+    if (vec->isConstantEncoding()) {
+      return true;
+    }
+    VELOX_USER_CHECK(vec->size() > 0);
+    for (vector_size_t i = 1; i < vec->size(); ++i) {
+      if (!vec->equalValueAt(vec.get(), i, 0)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   template <bool kSingleGroup, bool checkIntermediateInputs>
   void addIntermediateImpl(
       std::conditional_t<kSingleGroup, char*, char**> group,
@@ -650,7 +663,8 @@ class ApproxPercentileAggregate : public exec::Aggregate {
     if constexpr (checkIntermediateInputs) {
       VELOX_USER_CHECK(rowVec);
       for (int i = kPercentiles; i <= kAccuracy; ++i) {
-        VELOX_USER_CHECK(rowVec->childAt(i)->isConstantEncoding());
+        VELOX_USER_CHECK(isConstantVector(
+            rowVec->childAt(i))); // spilling flattens constant encoding
       }
       for (int i = kK; i <= kMaxValue; ++i) {
         VELOX_USER_CHECK(rowVec->childAt(i)->isFlatEncoding());
       }
@@ -677,10 +691,9 @@ class ApproxPercentileAggregate : public exec::Aggregate {
     }
 
     DecodedVector percentiles(*rowVec->childAt(kPercentiles), *baseRows);
-    auto percentileIsArray =
-        rowVec->childAt(kPercentilesIsArray)->asUnchecked<SimpleVector<bool>>();
-    auto accuracy =
-        rowVec->childAt(kAccuracy)->asUnchecked<SimpleVector<double>>();
+    DecodedVector percentileIsArray(
+        *rowVec->childAt(kPercentilesIsArray), *baseRows);
+    DecodedVector accuracy(*rowVec->childAt(kAccuracy), *baseRows);
     auto k = rowVec->childAt(kK)->asUnchecked<SimpleVector<int32_t>>();
     auto n = rowVec->childAt(kN)->asUnchecked<SimpleVector<int64_t>>();
     auto minValue = rowVec->childAt(kMinValue)->asUnchecked<SimpleVector<T>>();
@@ -710,7 +723,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
         return;
       }
       int i = decoded.index(row);
-      if (percentileIsArray->isNullAt(i)) {
+      if (percentileIsArray.isNullAt(i)) {
        return;
       }
       if (!accumulator) {
@@ -720,10 +733,10 @@ class ApproxPercentileAggregate : public exec::Aggregate {
             percentilesBase->elements()->asFlatVector<double>();
         if constexpr (checkIntermediateInputs) {
           VELOX_USER_CHECK(percentileBaseElements);
-          VELOX_USER_CHECK(!percentilesBase->isNullAt(indexInBaseVector));
+          VELOX_USER_CHECK(!percentiles.isNullAt(indexInBaseVector));
         }
 
-        bool isArray = percentileIsArray->valueAt(i);
+        bool isArray = percentileIsArray.valueAt(i);
         const double* data;
         vector_size_t len;
         std::vector<bool> isNull;
@@ -731,8 +744,8 @@ class ApproxPercentileAggregate : public exec::Aggregate {
             percentilesBase, indexInBaseVector, data, len, isNull);
         checkSetPercentile(isArray, data, len, isNull);
-        if (!accuracy->isNullAt(i)) {
-          checkSetAccuracy(accuracy->valueAt(i));
+        if (!accuracy.isNullAt(i)) {
+          checkSetAccuracy(accuracy.valueAt(i));
         }
       }
       if constexpr (kSingleGroup) {
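
Note for reviewers: the sketch below distills the new reclaimFromPartialAggregation test into the smallest setup that turns the feature on. It only reuses builder calls that already appear in this patch (AssertQueryBuilder, PlanBuilder, TempDirectoryPath, the QueryConfig keys); the fixture members `vectors`, `duckDbQueryRunner_`, and the DuckDB table `tmp` are assumed to come from the surrounding test harness. Enabling the flags by themselves does not force a spill: the partial aggregation still spills only under memory pressure, which is why the full test above also caps the query's memory and disables partial flush and abandonment.

```cpp
// Minimal sketch, assuming an exec test fixture that provides `vectors`,
// `duckDbQueryRunner_`, and a DuckDB table named "tmp" (as in the test above).
auto spillDirectory = exec::test::TempDirectoryPath::create();
auto task =
    AssertQueryBuilder(duckDbQueryRunner_)
        .spillDirectory(spillDirectory->path)
        // Both switches are required: the global spill switch plus the new
        // partial_aggregation_spill_enabled flag added in QueryConfig.h.
        .config(core::QueryConfig::kSpillEnabled, "true")
        .config(core::QueryConfig::kPartialAggregationSpillEnabled, "true")
        .plan(PlanBuilder()
                  .values(vectors)
                  .partialAggregation({"c0"}, {"count(1)"})
                  .finalAggregation()
                  .planNode())
        .assertResults("SELECT c0, count(1) FROM tmp GROUP BY c0");
```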