Skip to content

Commit

Permalink
add agg info
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Dec 5, 2023
1 parent 023630d commit 75a8bdb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
54 changes: 30 additions & 24 deletions velox/exec/StreamingAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,11 @@ void StreamingAggregation::initialize() {
const auto numAggregates = aggregationNode_->aggregates().size();
aggregates_.reserve(numAggregates);
std::vector<Accumulator> accumulators;
accumulators.reserve(aggregates_.size());
std::vector<std::optional<column_index_t>> maskChannels;
maskChannels.reserve(numAggregates);
accumulators.reserve(numAggregates);

for (auto i = 0; i < numAggregates; i++) {
const auto& aggregate = aggregationNode_->aggregates()[i];

AggregateInfo info;
if (!aggregate.sortingKeys.empty()) {
VELOX_UNSUPPORTED(
"Streaming aggregation doesn't support aggregations over sorted inputs yet");
Expand All @@ -72,8 +71,8 @@ void StreamingAggregation::initialize() {
"Streaming aggregation doesn't support aggregations over distinct inputs yet");
}

std::vector<column_index_t> channels;
std::vector<VectorPtr> constants;
auto& channels = info.inputs;
auto& constants = info.constantInputs;
for (auto& arg : aggregate.call->inputs()) {
channels.push_back(exprToChannel(arg.get(), inputType));
if (channels.back() == kConstantChannel) {
Expand All @@ -86,35 +85,39 @@ void StreamingAggregation::initialize() {
}

if (const auto& mask = aggregate.mask) {
maskChannels.emplace_back(inputType->asRow().getChildIdx(mask->name()));
info.mask = inputType->asRow().getChildIdx(mask->name());
} else {
maskChannels.emplace_back(std::nullopt);
info.mask = std::nullopt;
}

const auto& aggResultType = outputType_->childAt(numKeys + i);
aggregates_.push_back(Aggregate::create(
info.function = Aggregate::create(
aggregate.call->name(),
isPartialOutput(aggregationNode_->step())
? core::AggregationNode::Step::kPartial
: core::AggregationNode::Step::kSingle,
aggregate.rawInputTypes,
aggResultType,
operatorCtx_->driverCtx()->queryConfig()));
args_.push_back(channels);
constantArgs_.push_back(constants);
operatorCtx_->driverCtx()->queryConfig());

const auto intermediateType = Aggregate::intermediateType(
aggregate.call->name(), aggregate.rawInputTypes);
accumulators.push_back(
Accumulator{aggregates_.back().get(), std::move(intermediateType)});
accumulators.emplace_back(info.function.get(), intermediateType);

aggregates_.emplace_back(std::move(info));
}

if (aggregationNode_->ignoreNullKeys()) {
VELOX_UNSUPPORTED(
"Streaming aggregation doesn't support ignoring null keys yet");
}

masks_ = std::make_unique<AggregationMasks>(std::move(maskChannels));
std::vector<std::optional<column_index_t>> masks;
masks.reserve(numAggregates);
for (const auto& aggregate : aggregates_) {
masks.emplace_back(aggregate.mask);
}
masks_ = std::make_unique<AggregationMasks>(std::move(masks));

rows_ = std::make_unique<RowContainer>(
groupingKeyTypes,
Expand All @@ -128,10 +131,11 @@ void StreamingAggregation::initialize() {
pool());

for (auto i = 0; i < aggregates_.size(); ++i) {
aggregates_[i]->setAllocator(&rows_->stringAllocator());
auto& function = aggregates_[i].function;
function->setAllocator(&rows_->stringAllocator());

const auto rowColumn = rows_->columnAt(numKeys + i);
aggregates_[i]->setOffsets(
function->setOffsets(
rowColumn.offset(),
rowColumn.nullByte(),
rowColumn.nullMask(),
Expand Down Expand Up @@ -203,7 +207,7 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) {

auto numKeys = groupingKeys_.size();
for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;
auto& result = output->childAt(numKeys + i);
if (isPartialOutput(step_)) {
aggregate->extractAccumulators(groups_.data(), numGroups, &result);
Expand Down Expand Up @@ -264,14 +268,16 @@ const SelectivityVector& StreamingAggregation::getSelectivityVector(

void StreamingAggregation::evaluateAggregates() {
for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;
auto& inputs = aggregates_.at(i).inputs;
auto& constantInputs = aggregates_.at(i).constantInputs;

std::vector<VectorPtr> args;
for (auto j = 0; j < args_[i].size(); ++j) {
if (args_[i][j] == kConstantChannel) {
args.push_back(constantArgs_[i][j]);
for (auto j = 0; j < inputs.size(); ++j) {
if (inputs[j] == kConstantChannel) {
args.push_back(constantInputs[j]);
} else {
args.push_back(input_->childAt(args_[i][j]));
args.push_back(input_->childAt(inputs[j]));
}
}

Expand Down Expand Up @@ -315,7 +321,7 @@ RowVectorPtr StreamingAggregation::getOutput() {
std::iota(newGroups.begin(), newGroups.end(), numPrevGroups);

for (auto i = 0; i < aggregates_.size(); ++i) {
auto& aggregate = aggregates_[i];
auto& aggregate = aggregates_.at(i).function;

aggregate->initializeNewGroups(
groups_.data(), folly::Range(newGroups.data(), newGroups.size()));
Expand Down
5 changes: 2 additions & 3 deletions velox/exec/StreamingAggregation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/exec/Aggregate.h"
#include "velox/exec/AggregateInfo.h"
#include "velox/exec/AggregationMasks.h"
#include "velox/exec/Operator.h"

Expand Down Expand Up @@ -79,10 +80,8 @@ class StreamingAggregation : public Operator {
const core::AggregationNode::Step step_;

std::vector<column_index_t> groupingKeys_;
std::vector<std::unique_ptr<Aggregate>> aggregates_;
std::vector<AggregateInfo> aggregates_;
std::unique_ptr<AggregationMasks> masks_;
std::vector<std::vector<column_index_t>> args_;
std::vector<std::vector<VectorPtr>> constantArgs_;
std::vector<DecodedVector> decodedKeys_;

// Storage of grouping keys and accumulators.
Expand Down

0 comments on commit 75a8bdb

Please sign in to comment.