diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index d79483e7424ae..1bab06a1a8336 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -340,4 +340,24 @@ void Aggregate::clearInternal() { numNulls_ = 0; } +void Aggregate::singleInputAsIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const { + VELOX_CHECK_EQ(args.size(), 1); + const auto& input = args[0]; + if (rows.isAllSelected()) { + result = input; + return; + } + VELOX_CHECK_NOT_NULL(result); + // Set result to NULL for rows that are masked out. + { + auto nulls = allocateNulls(rows.size(), allocator_->pool(), bits::kNull); + rows.clearNulls(nulls); + result->setNulls(nulls); + } + result->copy(input.get(), rows, nullptr); +} + } // namespace facebook::velox::exec diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 9a827db4d8230..d6bc12aefcdee 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -337,6 +337,13 @@ class Aggregate { // 'groups'. No-op for fixed length accumulators. virtual void destroyInternal(folly::Range groups) {} + // Helper function to pass single input argument directly as intermediate + // result. + void singleInputAsIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const; + // Shorthand for maintaining accumulator variable length size in // accumulator update methods. Use like: { auto tracker = // trackRowSize(group); update(group); } diff --git a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp index 55ed6fdc03bab..b0f7dec154d67 100644 --- a/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxDistinctAggregate.cpp @@ -150,6 +150,17 @@ class ApproxDistinctAggregate : public exec::Aggregate { return false; } + bool supportsToIntermediate() const final { + return hllAsRawInput_; + } + + void toIntermediate( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) const final { + singleInputAsIntermediate(rows, args, result); + } + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { if (hllAsFinalResult_) { diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index dadca2c4f2d94..dd30a3fb9b4a5 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -55,24 +55,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { const SelectivityVector& rows, std::vector& args, VectorPtr& result) const override { - const auto& input = args[0]; - if (rows.isAllSelected()) { - result = input; - return; - } - - auto* pool = BaseAggregate::allocator_->pool(); - - result = BaseVector::create(input->type(), rows.size(), pool); - - // Set result to NULL for rows that are masked out. - { - BufferPtr nulls = allocateNulls(rows.size(), pool, bits::kNull); - rows.clearNulls(nulls); - result->setNulls(nulls); - } - - result->copy(input.get(), rows, nullptr); + this->singleInputAsIntermediate(rows, args, result); } void extractValues(char** groups, int32_t numGroups, VectorPtr* result) diff --git a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp index 5156a084ce425..8db2126a21516 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxDistinctTest.cpp @@ -407,5 +407,20 @@ TEST_F(ApproxDistinctTest, mergeWithEmpty) { ASSERT_EQ(readSingleValue(op).value(), 499); } +TEST_F(ApproxDistinctTest, toIntermediate) { + constexpr int kSize = 1000; + auto input = makeRowVector({ + makeFlatVector(kSize, folly::identity), + makeConstant(1, kSize), + }); + auto plan = PlanBuilder() + .values({input}) + .singleAggregation({"c0"}, {"approx_set(c1)"}) + .planNode(); + auto digests = split(AssertQueryBuilder(plan).copyResults(pool()), 2); + testAggregations( + digests, {"c0"}, {"merge(a0)"}, {"c0", "cardinality(a0)"}, {input}); +} + } // namespace } // namespace facebook::velox::aggregate::test