Skip to content

Commit

Permalink
Extract toAggregateInfo(AggregationNode) helper function (facebookinc…
Browse files Browse the repository at this point in the history
…ubator#7859)

Summary:
Refactor HashAggregation to extract toAggregateInfo(AggregationNode).

Use the new function in StreamingAggregation to reduce copy-paste.

Part of facebookincubator#7665

Pull Request resolved: facebookincubator#7859

Differential Revision: D51895447

Pulled By: mbasmanova

fbshipit-source-id: ce6d23607a6171d02c489251a268175ffd2363f6
  • Loading branch information
duanmeng authored and facebook-github-bot committed Dec 7, 2023
1 parent 13d54c4 commit 4827537
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 185 deletions.
147 changes: 147 additions & 0 deletions velox/exec/AggregateInfo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/exec/AggregateInfo.h"
#include "velox/exec/Aggregate.h"
#include "velox/exec/Operator.h"
#include "velox/expression/Expr.h"

namespace facebook::velox::exec {

namespace {
// Returns the arguments of the aggregate's call expression that are lambda
// expressions, in the order they appear in the call. Arguments of any other
// kind are skipped.
std::vector<core::LambdaTypedExprPtr> extractLambdaInputs(
    const core::AggregationNode::Aggregate& aggregate) {
  std::vector<core::LambdaTypedExprPtr> result;
  for (const auto& input : aggregate.call->inputs()) {
    auto asLambda =
        std::dynamic_pointer_cast<const core::LambdaTypedExpr>(input);
    if (asLambda == nullptr) {
      // Not a lambda argument.
      continue;
    }
    result.push_back(asLambda);
  }

  return result;
}
} // namespace

// Translates the aggregates of 'aggregationNode' into a list of AggregateInfo,
// one entry per aggregate, in the same order.
//
// For each aggregate this resolves the input channels / constant inputs, the
// optional mask channel, the intermediate type, the Aggregate function
// instance, the sorting keys and orders, and the output channel
// ('numKeys' + position of the aggregate).
//
// @param aggregationNode Plan node describing the aggregation.
// @param operatorCtx Operator context; supplies the memory pool used for
// constant vectors, the query config and the exec context.
// @param numKeys Number of grouping keys; aggregate results follow the keys
// in the node's output type.
// @param expressionEvaluator In/out. If null and some aggregate has lambda
// arguments (non-streaming only), it is set to a new
// SimpleExpressionEvaluator; a non-null evaluator is reused as is.
// @param isStreaming When true, aggregates with sorting keys or 'distinct'
// raise VELOX_UNSUPPORTED, lambda arguments fail a VELOX_USER_CHECK, and no
// lambda expressions are installed on the functions.
// @return List of AggregateInfo.
std::vector<AggregateInfo> toAggregateInfo(
    const core::AggregationNode& aggregationNode,
    const OperatorCtx& operatorCtx,
    uint32_t numKeys,
    std::shared_ptr<core::ExpressionEvaluator>& expressionEvaluator,
    bool isStreaming) {
  const auto& nodeAggregates = aggregationNode.aggregates();
  const auto numAggregates = nodeAggregates.size();
  std::vector<AggregateInfo> aggregates;
  aggregates.reserve(numAggregates);

  const auto& inputType = aggregationNode.sources()[0]->outputType();
  const auto& outputType = aggregationNode.outputType();
  const auto step = aggregationNode.step();

  // The step passed to Aggregate::create() is the same for every aggregate,
  // so compute it once.
  const auto functionStep = isPartialOutput(step)
      ? core::AggregationNode::Step::kPartial
      : core::AggregationNode::Step::kSingle;

  // Unsigned index: avoids comparing a signed 'int' against the size_t count
  // while keeping 'numKeys + i' a uint32_t.
  for (uint32_t i = 0; i < numAggregates; ++i) {
    const auto& aggregate = nodeAggregates[i];

    // TODO: Add support for sorted and distinct aggregations to
    // StreamingAggregation.
    if (isStreaming && !aggregate.sortingKeys.empty()) {
      VELOX_UNSUPPORTED(
          "Streaming aggregation doesn't support aggregations over sorted inputs yet");
    }
    if (isStreaming && aggregate.distinct) {
      VELOX_UNSUPPORTED(
          "Streaming aggregation doesn't support aggregations over distinct inputs yet");
    }

    AggregateInfo info;
    // Populate input. 'channels' and 'constants' are filled in lock-step:
    // every pushed channel gets a matching constant entry (nullptr when the
    // input is not a constant).
    auto& channels = info.inputs;
    auto& constants = info.constantInputs;
    for (const auto& arg : aggregate.call->inputs()) {
      if (auto field =
              dynamic_cast<const core::FieldAccessTypedExpr*>(arg.get())) {
        channels.push_back(inputType->getChildIdx(field->name()));
        constants.push_back(nullptr);
      } else if (
          auto constant =
              dynamic_cast<const core::ConstantTypedExpr*>(arg.get())) {
        channels.push_back(kConstantChannel);
        constants.push_back(constant->toConstantVector(operatorCtx.pool()));
      } else if (
          auto lambda = dynamic_cast<const core::LambdaTypedExpr*>(arg.get())) {
        VELOX_USER_CHECK(
            !isStreaming,
            "StreamingAggregation doesn't support lambda functions yet.");
        // Names in the lambda signature that match an input column add that
        // column's channel as an input.
        for (const auto& name : lambda->signature()->names()) {
          if (auto captureIndex = inputType->getChildIdxIfExists(name)) {
            channels.push_back(captureIndex.value());
            constants.push_back(nullptr);
          }
        }
      } else {
        VELOX_FAIL(
            "Expression must be field access, constant, or "
            "lambda (HashAggregation): {}",
            arg->toString());
      }
    }

    info.distinct = aggregate.distinct;
    info.intermediateType = Aggregate::intermediateType(
        aggregate.call->name(), aggregate.rawInputTypes);

    // Setup aggregation mask: convert the Variable Reference name to the
    // channel (projection) index, if there is a mask.
    if (const auto& mask = aggregate.mask) {
      info.mask = inputType->asRow().getChildIdx(mask->name());
    } else {
      info.mask = std::nullopt;
    }

    // Aggregate results are located after the grouping keys in the output.
    const auto index = numKeys + i;
    const auto& aggResultType = outputType->childAt(index);
    info.function = Aggregate::create(
        aggregate.call->name(),
        functionStep,
        aggregate.rawInputTypes,
        aggResultType,
        operatorCtx.driverCtx()->queryConfig());

    if (!isStreaming) {
      auto lambdas = extractLambdaInputs(aggregate);
      if (!lambdas.empty()) {
        // Create the evaluator on first use only; later aggregates (and the
        // caller) share the same instance.
        if (expressionEvaluator == nullptr) {
          expressionEvaluator = std::make_shared<SimpleExpressionEvaluator>(
              operatorCtx.execCtx()->queryCtx(), operatorCtx.execCtx()->pool());
        }
        info.function->setLambdaExpressions(lambdas, expressionEvaluator);
      }
    }

    // Sorting keys and orders.
    const auto numSortingKeys = aggregate.sortingKeys.size();
    VELOX_CHECK_EQ(numSortingKeys, aggregate.sortingOrders.size());
    info.sortingOrders = aggregate.sortingOrders;
    info.sortingKeys.reserve(numSortingKeys);
    for (const auto& key : aggregate.sortingKeys) {
      info.sortingKeys.push_back(exprToChannel(key.get(), inputType));
    }

    info.output = index;
    aggregates.emplace_back(std::move(info));
  }
  return aggregates;
}

} // namespace facebook::velox::exec
25 changes: 25 additions & 0 deletions velox/exec/AggregateInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,29 @@ struct AggregateInfo {
/// Type of intermediate results. Used for spilling.
TypePtr intermediateType;
};

class OperatorCtx;

/// Translates an AggregationNode into a list of AggregateInfo, one entry per
/// aggregate of the node. The node may belong to either a hash aggregation or
/// a streaming aggregation.
///
/// @param aggregationNode Plan node of this aggregation.
/// @param operatorCtx Operator context.
/// @param numKeys Number of grouping keys.
/// @param expressionEvaluator An expression evaluator used by an aggregate
/// operator to compile and evaluate lambda expressions. In/out: if it is null
/// and some aggregate has lambda inputs, it is set to a newly created
/// evaluator; if it is already set, it is reused, so it is assigned at most
/// once.
/// @param isStreaming Indicates whether this aggregation is streaming or not.
/// Pass true if the aggregate operator is a StreamingAggregation and false if
/// the aggregate operator is a HashAggregation. When true, aggregations over
/// sorted or distinct inputs and lambda functions are rejected. This
/// parameter will be removed after sorted, distinct aggregation, and lambda
/// functions support are added to StreamingAggregation.
/// @return List of AggregateInfo.
std::vector<AggregateInfo> toAggregateInfo(
const core::AggregationNode& aggregationNode,
const OperatorCtx& operatorCtx,
uint32_t numKeys,
std::shared_ptr<core::ExpressionEvaluator>& expressionEvaluator,
bool isStreaming = false);

} // namespace facebook::velox::exec
1 change: 1 addition & 0 deletions velox/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(
AggregateCompanionAdapter.cpp
AggregateCompanionSignatures.cpp
AggregateFunctionRegistry.cpp
AggregateInfo.cpp
AggregationMasks.cpp
AggregateWindow.cpp
ArrowStream.cpp
Expand Down
113 changes: 2 additions & 111 deletions velox/exec/HashAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,11 @@
*/
#include "velox/exec/HashAggregation.h"
#include <optional>
#include "velox/exec/Aggregate.h"
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/SortedAggregations.h"
#include "velox/exec/Task.h"
#include "velox/expression/Expr.h"

namespace facebook::velox::exec {

namespace {
// Gathers the lambda-expression arguments of the aggregate's call, keeping
// their relative order. Non-lambda arguments are ignored.
std::vector<core::LambdaTypedExprPtr> extractLambdaInputs(
    const core::AggregationNode::Aggregate& aggregate) {
  const auto& callInputs = aggregate.call->inputs();

  std::vector<core::LambdaTypedExprPtr> lambdaArgs;
  for (const auto& callInput : callInputs) {
    const auto lambdaArg =
        std::dynamic_pointer_cast<const core::LambdaTypedExpr>(callInput);
    if (lambdaArg != nullptr) {
      lambdaArgs.push_back(lambdaArg);
    }
  }
  return lambdaArgs;
}

// Fills 'info.inputs' and 'info.constantInputs' from the arguments of the
// aggregate's call expression. The two vectors grow in lock-step:
//  - field access: channel of the named input column, nullptr constant;
//  - constant: kConstantChannel, constant vector allocated from 'pool';
//  - lambda: for each name in the lambda signature that matches an input
//    column, that column's channel and a nullptr constant.
// Any other kind of argument fails with VELOX_FAIL.
void populateAggregateInputs(
    const core::AggregationNode::Aggregate& aggregate,
    const RowType& inputType,
    AggregateInfo& info,
    memory::MemoryPool* pool) {
  for (const auto& input : aggregate.call->inputs()) {
    const auto* rawInput = input.get();

    if (const auto* fieldAccess =
            dynamic_cast<const core::FieldAccessTypedExpr*>(rawInput)) {
      info.inputs.push_back(inputType.getChildIdx(fieldAccess->name()));
      info.constantInputs.push_back(nullptr);
      continue;
    }

    if (const auto* constantExpr =
            dynamic_cast<const core::ConstantTypedExpr*>(rawInput)) {
      info.inputs.push_back(kConstantChannel);
      info.constantInputs.push_back(constantExpr->toConstantVector(pool));
      continue;
    }

    if (const auto* lambdaExpr =
            dynamic_cast<const core::LambdaTypedExpr*>(rawInput)) {
      for (const auto& signatureName : lambdaExpr->signature()->names()) {
        const auto columnIndex = inputType.getChildIdxIfExists(signatureName);
        if (columnIndex.has_value()) {
          info.inputs.push_back(columnIndex.value());
          info.constantInputs.push_back(nullptr);
        }
      }
      continue;
    }

    VELOX_FAIL(
        "Expression must be field access, constant, or lambda: {}",
        input->toString());
  }
}

} // namespace

HashAggregation::HashAggregation(
int32_t operatorId,
DriverCtx* driverCtx,
Expand Down Expand Up @@ -107,7 +54,6 @@ void HashAggregation::initialize() {
VELOX_CHECK(pool()->trackUsage());

auto inputType = aggregationNode_->sources()[0]->outputType();

auto hashers =
createVectorHashers(inputType, aggregationNode_->groupingKeys());
auto numHashers = hashers.size();
Expand All @@ -119,64 +65,9 @@ void HashAggregation::initialize() {
preGroupedChannels.push_back(channel);
}

const auto numAggregates = aggregationNode_->aggregates().size();
std::vector<AggregateInfo> aggregateInfos;
aggregateInfos.reserve(numAggregates);

std::shared_ptr<core::ExpressionEvaluator> expressionEvaluator;

for (auto i = 0; i < numAggregates; i++) {
const auto& aggregate = aggregationNode_->aggregates()[i];

AggregateInfo info;
info.distinct = aggregate.distinct;
populateAggregateInputs(aggregate, inputType->asRow(), info, pool());

info.intermediateType = Aggregate::intermediateType(
aggregate.call->name(), aggregate.rawInputTypes);

// Setup aggregation mask: convert the Variable Reference name to the
// channel (projection) index, if there is a mask.
if (const auto& mask = aggregate.mask) {
if (mask != nullptr) {
info.mask = inputType->asRow().getChildIdx(mask->name());
}
}

const auto& resultType = outputType_->childAt(numHashers + i);
info.function = Aggregate::create(
aggregate.call->name(),
isPartialOutput(aggregationNode_->step())
? core::AggregationNode::Step::kPartial
: core::AggregationNode::Step::kSingle,
aggregate.rawInputTypes,
resultType,
operatorCtx_->driverCtx()->queryConfig());

auto lambdas = extractLambdaInputs(aggregate);
if (!lambdas.empty()) {
if (expressionEvaluator == nullptr) {
expressionEvaluator = std::make_shared<SimpleExpressionEvaluator>(
operatorCtx_->execCtx()->queryCtx(),
operatorCtx_->execCtx()->pool());
}
info.function->setLambdaExpressions(lambdas, expressionEvaluator);
}

info.output = numHashers + i;

// Sorting keys and orders.
const auto numSortingKeys = aggregate.sortingKeys.size();
VELOX_CHECK_EQ(numSortingKeys, aggregate.sortingOrders.size());
info.sortingOrders = aggregate.sortingOrders;

info.sortingKeys.reserve(numSortingKeys);
for (const auto& key : aggregate.sortingKeys) {
info.sortingKeys.push_back(exprToChannel(key.get(), inputType));
}

aggregateInfos.emplace_back(std::move(info));
}
std::vector<AggregateInfo> aggregateInfos = toAggregateInfo(
*aggregationNode_, *operatorCtx_, numHashers, expressionEvaluator);

// Check that aggregate result type match the output type.
for (auto i = 0; i < aggregateInfos.size(); i++) {
Expand Down
Loading

0 comments on commit 4827537

Please sign in to comment.