diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp
index 2bc658ba0a9c8..cc48174034bbd 100644
--- a/velox/exec/GroupingSet.cpp
+++ b/velox/exec/GroupingSet.cpp
@@ -1205,6 +1205,9 @@ void GroupingSet::updateRow(SpillMergeStream& input, char* row) {
 }
 
 void GroupingSet::abandonPartialAggregation() {
+  // Partial aggregation must not be abandoned once spilling has happened.
+  VELOX_CHECK(!hasSpilled());
+
   abandonedPartialAggregation_ = true;
   allSupportToIntermediate_ = true;
   for (auto& aggregate : aggregates_) {
diff --git a/velox/exec/tests/SharedArbitratorTest.cpp b/velox/exec/tests/SharedArbitratorTest.cpp
index aa77df460a649..7f7983e7db20d 100644
--- a/velox/exec/tests/SharedArbitratorTest.cpp
+++ b/velox/exec/tests/SharedArbitratorTest.cpp
@@ -1228,6 +1228,59 @@ TEST_F(SharedArbitrationTest, reclaimFromPartialAggregation) {
   waitForAllTasksToBeDeleted();
 }
 
+// Checks that partial aggregation spills instead of flushing when spilling is
+// enabled, even with the partial aggregation memory limits set to one byte:
+// no "flushRowCount" stat may be reported, and both the partial and the final
+// aggregation must report spilled bytes.
+TEST_F(
+    SharedArbitrationTest,
+    reclaimFromPartialAggregationAndIgnoreFlushingSettings) {
+  const uint64_t maxQueryCapacity = 20L << 20;
+  std::vector<RowVectorPtr> vectors = newVectors(1024, maxQueryCapacity * 2);
+  createDuckDbTable(vectors);
+  const auto spillDirectory = exec::test::TempDirectoryPath::create();
+  core::PlanNodeId partialAggNodeId;
+  core::PlanNodeId finalAggNodeId;
+  std::shared_ptr<core::QueryCtx> queryCtx = newQueryCtx(maxQueryCapacity);
+  auto task =
+      AssertQueryBuilder(duckDbQueryRunner_)
+          .spillDirectory(spillDirectory->path)
+          .config(core::QueryConfig::kSpillEnabled, "true")
+          .config(core::QueryConfig::kPartialAggregationSpillEnabled, "true")
+          .config(core::QueryConfig::kAggregationSpillEnabled, "true")
+          .config(
+              core::QueryConfig::kMaxPartialAggregationMemory,
+              std::to_string(1L))
+          .config(
+              core::QueryConfig::kMaxExtendedPartialAggregationMemory,
+              std::to_string(1L))
+          .config(
+              core::QueryConfig::kAbandonPartialAggregationMinPct,
+              "200") // avoid abandoning
+          .config(
+              core::QueryConfig::kAbandonPartialAggregationMinRows,
+              std::to_string(1LL << 30)) // avoid abandoning
+          .queryCtx(queryCtx)
+          .plan(PlanBuilder()
+                    .values(vectors)
+                    .partialAggregation({"c0"}, {"count(1)"})
+                    .capturePlanNodeId(partialAggNodeId)
+                    .finalAggregation()
+                    .capturePlanNodeId(finalAggNodeId)
+                    .planNode())
+          .assertResults("SELECT c0, count(1) FROM tmp GROUP BY c0");
+  auto taskStats = exec::toPlanStats(task->taskStats());
+  auto& partialStats = taskStats.at(partialAggNodeId);
+  auto& finalStats = taskStats.at(finalAggNodeId);
+  ASSERT_EQ(
+      partialStats.customStats.find("flushRowCount"),
+      partialStats.customStats.end());
+  ASSERT_GT(partialStats.spilledBytes, 0);
+  ASSERT_GT(finalStats.spilledBytes, 0);
+  task.reset();
+  waitForAllTasksToBeDeleted();
+}
+
 DEBUG_ONLY_TEST_F(SharedArbitrationTest, reclaimFromAggregationOnNoMoreInput) {
   const int numVectors = 32;
   std::vector<RowVectorPtr> vectors;