diff --git a/velox/exec/StreamingAggregation.cpp b/velox/exec/StreamingAggregation.cpp index cace7e461a06..96ef96ee577a 100644 --- a/velox/exec/StreamingAggregation.cpp +++ b/velox/exec/StreamingAggregation.cpp @@ -56,12 +56,11 @@ void StreamingAggregation::initialize() { const auto numAggregates = aggregationNode_->aggregates().size(); aggregates_.reserve(numAggregates); std::vector accumulators; - accumulators.reserve(aggregates_.size()); - std::vector> maskChannels; - maskChannels.reserve(numAggregates); + accumulators.reserve(numAggregates); + for (auto i = 0; i < numAggregates; i++) { const auto& aggregate = aggregationNode_->aggregates()[i]; - + AggregateInfo info; if (!aggregate.sortingKeys.empty()) { VELOX_UNSUPPORTED( "Streaming aggregation doesn't support aggregations over sorted inputs yet"); @@ -72,8 +71,8 @@ void StreamingAggregation::initialize() { "Streaming aggregation doesn't support aggregations over distinct inputs yet"); } - std::vector channels; - std::vector constants; + auto& channels = info.inputs; + auto& constants = info.constantInputs; for (auto& arg : aggregate.call->inputs()) { channels.push_back(exprToChannel(arg.get(), inputType)); if (channels.back() == kConstantChannel) { @@ -86,27 +85,26 @@ void StreamingAggregation::initialize() { } if (const auto& mask = aggregate.mask) { - maskChannels.emplace_back(inputType->asRow().getChildIdx(mask->name())); + info.mask = inputType->asRow().getChildIdx(mask->name()); } else { - maskChannels.emplace_back(std::nullopt); + info.mask = std::nullopt; } const auto& aggResultType = outputType_->childAt(numKeys + i); - aggregates_.push_back(Aggregate::create( + info.function = Aggregate::create( aggregate.call->name(), isPartialOutput(aggregationNode_->step()) ? core::AggregationNode::Step::kPartial : core::AggregationNode::Step::kSingle, aggregate.rawInputTypes, aggResultType, - operatorCtx_->driverCtx()->queryConfig())); - args_.push_back(channels); - constantArgs_.push_back(constants); + operatorCtx_->driverCtx()->queryConfig()); const auto intermediateType = Aggregate::intermediateType( aggregate.call->name(), aggregate.rawInputTypes); - accumulators.push_back( - Accumulator{aggregates_.back().get(), std::move(intermediateType)}); + accumulators.emplace_back(info.function.get(), intermediateType); + + aggregates_.emplace_back(std::move(info)); } if (aggregationNode_->ignoreNullKeys()) { @@ -114,7 +112,12 @@ void StreamingAggregation::initialize() { "Streaming aggregation doesn't support ignoring null keys yet"); } - masks_ = std::make_unique(std::move(maskChannels)); + std::vector> masks; + masks.reserve(numAggregates); + for (const auto& aggregate : aggregates_) { + masks.emplace_back(aggregate.mask); + } + masks_ = std::make_unique(std::move(masks)); rows_ = std::make_unique( groupingKeyTypes, @@ -128,10 +131,11 @@ void StreamingAggregation::initialize() { pool()); for (auto i = 0; i < aggregates_.size(); ++i) { - aggregates_[i]->setAllocator(&rows_->stringAllocator()); + auto& function = aggregates_[i].function; + function->setAllocator(&rows_->stringAllocator()); const auto rowColumn = rows_->columnAt(numKeys + i); - aggregates_[i]->setOffsets( + function->setOffsets( rowColumn.offset(), rowColumn.nullByte(), rowColumn.nullMask(), @@ -203,7 +207,7 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { auto numKeys = groupingKeys_.size(); for (auto i = 0; i < aggregates_.size(); ++i) { - auto& aggregate = aggregates_[i]; + auto& aggregate = aggregates_.at(i).function; auto& result = output->childAt(numKeys + i); if (isPartialOutput(step_)) { aggregate->extractAccumulators(groups_.data(), numGroups, &result); @@ -264,14 +268,16 @@ const SelectivityVector& StreamingAggregation::getSelectivityVector( void StreamingAggregation::evaluateAggregates() { for (auto i = 0; i < aggregates_.size(); ++i) { - auto& aggregate = aggregates_[i]; + auto& aggregate = aggregates_.at(i).function; + auto& inputs = aggregates_.at(i).inputs; + auto& constantInputs = aggregates_.at(i).constantInputs; std::vector args; - for (auto j = 0; j < args_[i].size(); ++j) { - if (args_[i][j] == kConstantChannel) { - args.push_back(constantArgs_[i][j]); + for (auto j = 0; j < inputs.size(); ++j) { + if (inputs[j] == kConstantChannel) { + args.push_back(constantInputs[j]); } else { - args.push_back(input_->childAt(args_[i][j])); + args.push_back(input_->childAt(inputs[j])); } } @@ -315,7 +321,7 @@ RowVectorPtr StreamingAggregation::getOutput() { std::iota(newGroups.begin(), newGroups.end(), numPrevGroups); for (auto i = 0; i < aggregates_.size(); ++i) { - auto& aggregate = aggregates_[i]; + auto& aggregate = aggregates_.at(i).function; aggregate->initializeNewGroups( groups_.data(), folly::Range(newGroups.data(), newGroups.size())); diff --git a/velox/exec/StreamingAggregation.h b/velox/exec/StreamingAggregation.h index e60f57760572..3307dfaaaee2 100644 --- a/velox/exec/StreamingAggregation.h +++ b/velox/exec/StreamingAggregation.h @@ -16,6 +16,7 @@ #pragma once #include "velox/exec/Aggregate.h" +#include "velox/exec/AggregateInfo.h" #include "velox/exec/AggregationMasks.h" #include "velox/exec/Operator.h" @@ -79,10 +80,8 @@ class StreamingAggregation : public Operator { const core::AggregationNode::Step step_; std::vector groupingKeys_; - std::vector> aggregates_; + std::vector aggregates_; std::unique_ptr masks_; - std::vector> args_; - std::vector> constantArgs_; std::vector decodedKeys_; // Storage of grouping keys and accumulators.