Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract toAggregateInfo(AggregationNode) helper function #7859

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions velox/exec/AggregateUtil.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/exec/AggregateUtil.h"
#include "velox/exec/Aggregate.h"
#include "velox/expression/Expr.h"

namespace facebook::velox::exec {

namespace {
std::vector<core::LambdaTypedExprPtr> extractLambdaInputs(
const core::AggregationNode::Aggregate& aggregate) {
std::vector<core::LambdaTypedExprPtr> lambdas;
for (const auto& arg : aggregate.call->inputs()) {
if (auto lambda =
std::dynamic_pointer_cast<const core::LambdaTypedExpr>(arg)) {
lambdas.push_back(lambda);
}
}

return lambdas;
}
} // namespace

std::vector<AggregateInfo> AggregateUtil::toAggregateInfo(
const core::AggregationNode& aggregationNode,
const OperatorCtx& operatorCtx,
uint32_t numKeys,
std::shared_ptr<core::ExpressionEvaluator>& expressionEvaluator,
bool isStreaming) {
const auto numAggregates = aggregationNode.aggregates().size();
std::vector<AggregateInfo> aggregates;
aggregates.reserve(numAggregates);

const auto& inputType = aggregationNode.sources()[0]->outputType();
const auto& outputType = aggregationNode.outputType();
const auto step = aggregationNode.step();

for (auto i = 0; i < numAggregates; i++) {
const auto& aggregate = aggregationNode.aggregates()[i];

// TODO: Add support for StreamingAggregation
if (isStreaming && !aggregate.sortingKeys.empty()) {
VELOX_UNSUPPORTED(
"Streaming aggregation doesn't support aggregations over sorted inputs yet");
}
if (isStreaming && aggregate.distinct) {
VELOX_UNSUPPORTED(
"Streaming aggregation doesn't support aggregations over distinct inputs yet");
}

AggregateInfo info;
// Populate input.
auto& channels = info.inputs;
auto& constants = info.constantInputs;
for (const auto& arg : aggregate.call->inputs()) {
if (auto field =
dynamic_cast<const core::FieldAccessTypedExpr*>(arg.get())) {
channels.push_back(inputType->getChildIdx(field->name()));
constants.push_back(nullptr);
} else if (
auto constant =
dynamic_cast<const core::ConstantTypedExpr*>(arg.get())) {
channels.push_back(kConstantChannel);
constants.push_back(constant->toConstantVector(operatorCtx.pool()));
} else if (
auto lambda = dynamic_cast<const core::LambdaTypedExpr*>(arg.get())) {
VELOX_USER_CHECK(
!isStreaming,
"StreamingAggregation doesn't support lambda functions yet.");
for (const auto& name : lambda->signature()->names()) {
if (auto captureIndex = inputType->getChildIdxIfExists(name)) {
channels.push_back(captureIndex.value());
constants.push_back(nullptr);
}
}
} else {
VELOX_FAIL(
"Expression must be field access, constant, or "
"lambda (HashAggregation): {}",
arg->toString());
}
}

info.distinct = aggregate.distinct;
info.intermediateType = Aggregate::intermediateType(
aggregate.call->name(), aggregate.rawInputTypes);

// Setup aggregation mask: convert the Variable Reference name to the
// channel (projection) index, if there is a mask.
if (const auto& mask = aggregate.mask) {
info.mask = inputType->asRow().getChildIdx(mask->name());
} else {
info.mask = std::nullopt;
}

auto index = numKeys + i;
const auto& aggResultType = outputType->childAt(index);
info.function = Aggregate::create(
aggregate.call->name(),
isPartialOutput(step) ? core::AggregationNode::Step::kPartial
: core::AggregationNode::Step::kSingle,
aggregate.rawInputTypes,
aggResultType,
operatorCtx.driverCtx()->queryConfig());

if (!isStreaming) {
auto lambdas = extractLambdaInputs(aggregate);
if (!lambdas.empty()) {
if (expressionEvaluator == nullptr) {
expressionEvaluator = std::make_shared<SimpleExpressionEvaluator>(
operatorCtx.execCtx()->queryCtx(), operatorCtx.execCtx()->pool());
}
info.function->setLambdaExpressions(lambdas, expressionEvaluator);
}
}

// Sorting keys and orders.
const auto numSortingKeys = aggregate.sortingKeys.size();
VELOX_CHECK_EQ(numSortingKeys, aggregate.sortingOrders.size());
info.sortingOrders = aggregate.sortingOrders;
info.sortingKeys.reserve(numSortingKeys);
for (const auto& key : aggregate.sortingKeys) {
info.sortingKeys.push_back(exprToChannel(key.get(), inputType));
}

info.output = index;
aggregates.emplace_back(std::move(info));
}
return aggregates;
}

} // namespace facebook::velox::exec
29 changes: 29 additions & 0 deletions velox/exec/AggregateUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "velox/exec/AggregateInfo.h"
#include "velox/exec/Operator.h"

namespace facebook::velox::exec {

// The result of aggregation function registration.
Expand All @@ -34,4 +38,29 @@ struct AggregateRegistrationResult {
}
};

class AggregateUtil {
duanmeng marked this conversation as resolved.
Show resolved Hide resolved
public:
/// Translate an AggregationNode to a list of AggregationInfo, which could be
/// a hash aggregation plan node or a streaming aggregation plan node.
///
/// @param aggregationNode Plan node of this aggregation.
/// @param operatorCtx Operator context.
/// @param numKeys Number of grouping keys.
/// @param expressionEvaluator An Expression evaluator. It is used by an
/// aggregate operator to compile and eval lambda expression. It should be
/// initiated/assigned for at most one time.
/// @param isStreaming Indicate whether this aggregation is streaming or not.
/// Pass true if the aggregate operator is a StreamingAggregation and false if
/// the aggregate operator is a HashAggregation. This parameter will be
/// removed after sorted, distinct aggregation, and lambda functions support
/// are added to StreamingAggregation.
/// @return List of AggregationInfo.
static std::vector<AggregateInfo> toAggregateInfo(
const core::AggregationNode& aggregationNode,
const OperatorCtx& operatorCtx,
uint32_t numKeys,
std::shared_ptr<core::ExpressionEvaluator>& expressionEvaluator,
bool isStreaming = false);
};

} // namespace facebook::velox::exec
1 change: 1 addition & 0 deletions velox/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(
AggregateCompanionAdapter.cpp
AggregateCompanionSignatures.cpp
AggregateFunctionRegistry.cpp
AggregateUtil.cpp
AggregationMasks.cpp
AggregateWindow.cpp
ArrowStream.cpp
Expand Down
113 changes: 2 additions & 111 deletions velox/exec/HashAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,11 @@
*/
#include "velox/exec/HashAggregation.h"
#include <optional>
#include "velox/exec/Aggregate.h"
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/SortedAggregations.h"
#include "velox/exec/Task.h"
#include "velox/expression/Expr.h"

namespace facebook::velox::exec {

namespace {
std::vector<core::LambdaTypedExprPtr> extractLambdaInputs(
const core::AggregationNode::Aggregate& aggregate) {
std::vector<core::LambdaTypedExprPtr> lambdas;
for (const auto& arg : aggregate.call->inputs()) {
if (auto lambda =
std::dynamic_pointer_cast<const core::LambdaTypedExpr>(arg)) {
lambdas.push_back(lambda);
}
}

return lambdas;
}

void populateAggregateInputs(
const core::AggregationNode::Aggregate& aggregate,
const RowType& inputType,
AggregateInfo& info,
memory::MemoryPool* pool) {
auto& channels = info.inputs;
auto& constants = info.constantInputs;

for (const auto& arg : aggregate.call->inputs()) {
if (auto field =
dynamic_cast<const core::FieldAccessTypedExpr*>(arg.get())) {
channels.push_back(inputType.getChildIdx(field->name()));
constants.push_back(nullptr);
} else if (
auto constant =
dynamic_cast<const core::ConstantTypedExpr*>(arg.get())) {
channels.push_back(kConstantChannel);
constants.push_back(constant->toConstantVector(pool));
} else if (
auto lambda = dynamic_cast<const core::LambdaTypedExpr*>(arg.get())) {
for (const auto& name : lambda->signature()->names()) {
if (auto captureIndex = inputType.getChildIdxIfExists(name)) {
channels.push_back(captureIndex.value());
constants.push_back(nullptr);
}
}
} else {
VELOX_FAIL(
"Expression must be field access, constant, or lambda: {}",
arg->toString());
}
}
}

} // namespace

HashAggregation::HashAggregation(
int32_t operatorId,
DriverCtx* driverCtx,
Expand Down Expand Up @@ -107,7 +54,6 @@ void HashAggregation::initialize() {
VELOX_CHECK(pool()->trackUsage());

auto inputType = aggregationNode_->sources()[0]->outputType();

auto hashers =
createVectorHashers(inputType, aggregationNode_->groupingKeys());
auto numHashers = hashers.size();
Expand All @@ -119,64 +65,9 @@ void HashAggregation::initialize() {
preGroupedChannels.push_back(channel);
}

const auto numAggregates = aggregationNode_->aggregates().size();
std::vector<AggregateInfo> aggregateInfos;
aggregateInfos.reserve(numAggregates);

std::shared_ptr<core::ExpressionEvaluator> expressionEvaluator;

for (auto i = 0; i < numAggregates; i++) {
const auto& aggregate = aggregationNode_->aggregates()[i];

AggregateInfo info;
info.distinct = aggregate.distinct;
populateAggregateInputs(aggregate, inputType->asRow(), info, pool());

info.intermediateType = Aggregate::intermediateType(
aggregate.call->name(), aggregate.rawInputTypes);

// Setup aggregation mask: convert the Variable Reference name to the
// channel (projection) index, if there is a mask.
if (const auto& mask = aggregate.mask) {
if (mask != nullptr) {
info.mask = inputType->asRow().getChildIdx(mask->name());
}
}

const auto& resultType = outputType_->childAt(numHashers + i);
info.function = Aggregate::create(
aggregate.call->name(),
isPartialOutput(aggregationNode_->step())
? core::AggregationNode::Step::kPartial
: core::AggregationNode::Step::kSingle,
aggregate.rawInputTypes,
resultType,
operatorCtx_->driverCtx()->queryConfig());

auto lambdas = extractLambdaInputs(aggregate);
if (!lambdas.empty()) {
if (expressionEvaluator == nullptr) {
expressionEvaluator = std::make_shared<SimpleExpressionEvaluator>(
operatorCtx_->execCtx()->queryCtx(),
operatorCtx_->execCtx()->pool());
}
info.function->setLambdaExpressions(lambdas, expressionEvaluator);
}

info.output = numHashers + i;

// Sorting keys and orders.
const auto numSortingKeys = aggregate.sortingKeys.size();
VELOX_CHECK_EQ(numSortingKeys, aggregate.sortingOrders.size());
info.sortingOrders = aggregate.sortingOrders;

info.sortingKeys.reserve(numSortingKeys);
for (const auto& key : aggregate.sortingKeys) {
info.sortingKeys.push_back(exprToChannel(key.get(), inputType));
}

aggregateInfos.emplace_back(std::move(info));
}
std::vector<AggregateInfo> aggregateInfos = AggregateUtil::toAggregateInfo(
*aggregationNode_, *operatorCtx_, numHashers, expressionEvaluator);

// Check that aggregate result type match the output type.
for (auto i = 0; i < aggregateInfos.size(); i++) {
Expand Down
Loading
Loading