Skip to content

Commit

Permalink
add agg info
Browse files: browse the repository at this point in the history
  • Loading branch information
duanmeng committed Dec 4, 2023
1 parent 023630d commit b54af7b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
40 changes: 21 additions & 19 deletions velox/exec/StreamingAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ void StreamingAggregation::initialize() {
}

const auto numAggregates = aggregationNode_->aggregates().size();
aggregates_.reserve(numAggregates);
std::vector<Accumulator> accumulators;
accumulators.reserve(aggregates_.size());
accumulators.reserve(numAggregates);
std::vector<std::optional<column_index_t>> maskChannels;
maskChannels.reserve(numAggregates);
for (auto i = 0; i < numAggregates; i++) {
const auto& aggregate = aggregationNode_->aggregates()[i];

AggregateInfo info;
if (!aggregate.sortingKeys.empty()) {
VELOX_UNSUPPORTED(
"Streaming aggregation doesn't support aggregations over sorted inputs yet");
Expand All @@ -72,8 +71,8 @@ void StreamingAggregation::initialize() {
"Streaming aggregation doesn't support aggregations over distinct inputs yet");
}

std::vector<column_index_t> channels;
std::vector<VectorPtr> constants;
auto& channels = info.inputs;
auto& constants = info.constantInputs;
for (auto& arg : aggregate.call->inputs()) {
channels.push_back(exprToChannel(arg.get(), inputType));
if (channels.back() == kConstantChannel) {
Expand All @@ -92,21 +91,21 @@ void StreamingAggregation::initialize() {
}

const auto& aggResultType = outputType_->childAt(numKeys + i);
aggregates_.push_back(Aggregate::create(
info.function = Aggregate::create(
aggregate.call->name(),
isPartialOutput(aggregationNode_->step())
? core::AggregationNode::Step::kPartial
: core::AggregationNode::Step::kSingle,
aggregate.rawInputTypes,
aggResultType,
operatorCtx_->driverCtx()->queryConfig()));
args_.push_back(channels);
constantArgs_.push_back(constants);
operatorCtx_->driverCtx()->queryConfig());

const auto intermediateType = Aggregate::intermediateType(
aggregate.call->name(), aggregate.rawInputTypes);
accumulators.push_back(
Accumulator{aggregates_.back().get(), std::move(intermediateType)});
Accumulator{info.function.get(), std::move(intermediateType)});

aggregates_.emplace_back(std::move(info));
}

if (aggregationNode_->ignoreNullKeys()) {
Expand All @@ -128,10 +127,11 @@ void StreamingAggregation::initialize() {
pool());

for (auto i = 0; i < aggregates_.size(); ++i) {
aggregates_[i]->setAllocator(&rows_->stringAllocator());
auto& function = aggregates_[i].function;
function->setAllocator(&rows_->stringAllocator());

const auto rowColumn = rows_->columnAt(numKeys + i);
aggregates_[i]->setOffsets(
function->setOffsets(
rowColumn.offset(),
rowColumn.nullByte(),
rowColumn.nullMask(),
Expand Down Expand Up @@ -203,7 +203,7 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) {

auto numKeys = groupingKeys_.size();
for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;
auto& result = output->childAt(numKeys + i);
if (isPartialOutput(step_)) {
aggregate->extractAccumulators(groups_.data(), numGroups, &result);
Expand Down Expand Up @@ -264,14 +264,16 @@ const SelectivityVector& StreamingAggregation::getSelectivityVector(

void StreamingAggregation::evaluateAggregates() {
for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;
auto& inputs = aggregates_.at(i).inputs;
auto& constantInputs = aggregates_.at(i).constantInputs;

std::vector<VectorPtr> args;
for (auto j = 0; j < args_[i].size(); ++j) {
if (args_[i][j] == kConstantChannel) {
args.push_back(constantArgs_[i][j]);
for (auto j = 0; j < inputs.size(); ++j) {
if (inputs[j] == kConstantChannel) {
args.push_back(constantInputs[j]);
} else {
args.push_back(input_->childAt(args_[i][j]));
args.push_back(input_->childAt(inputs[j]));
}
}

Expand Down Expand Up @@ -315,7 +317,7 @@ RowVectorPtr StreamingAggregation::getOutput() {
std::iota(newGroups.begin(), newGroups.end(), numPrevGroups);

for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;

aggregate->initializeNewGroups(
groups_.data(), folly::Range(newGroups.data(), newGroups.size()));
Expand Down
5 changes: 2 additions & 3 deletions velox/exec/StreamingAggregation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/exec/Aggregate.h"
#include "velox/exec/AggregateInfo.h"
#include "velox/exec/AggregationMasks.h"
#include "velox/exec/Operator.h"

Expand Down Expand Up @@ -79,10 +80,8 @@ class StreamingAggregation : public Operator {
const core::AggregationNode::Step step_;

std::vector<column_index_t> groupingKeys_;
std::vector<std::unique_ptr<Aggregate>> aggregates_;
std::vector<AggregateInfo> aggregates_;
std::unique_ptr<AggregationMasks> masks_;
std::vector<std::vector<column_index_t>> args_;
std::vector<std::vector<VectorPtr>> constantArgs_;
std::vector<DecodedVector> decodedKeys_;

// Storage of grouping keys and accumulators.
Expand Down

0 comments on commit b54af7b

Please sign in to comment.