Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-2163][CH] support aggregate function approx_percentile #4829

Merged
merged 11 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ case class CHHashAggregateExecTransformer(
fields = fields :+ (child.dataType, child.nullable)
}
(makeStructType(fields), false)
case approxPercentile: ApproximatePercentile =>
var fields = Seq[(DataType, Boolean)]()
// Use approxPercentile.nullable as the nullable of the struct type
// to make sure it returns null when input is empty
fields = fields :+ (approxPercentile.child.dataType, approxPercentile.nullable)
fields = fields :+ (approxPercentile.percentageExpression.dataType,
approxPercentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
spark.sql("drop table test_tbl_4997")
}

// Exercises approx_percentile as a grouped aggregate over lineitem, once with a
// scalar percentage and once with an array of percentages.
// NOTE(review): runQueryAndCompare presumably checks the result against vanilla
// Spark -- confirm against the suite's base class.
test("aggregate function approx_percentile") {
// single percentage
val sql1 = "select l_linenumber % 10, approx_percentile(l_extendedprice, 0.5) " +
"from lineitem group by l_linenumber % 10"
// The callback is a no-op: no additional assertions are made on the result here.
runQueryAndCompare(sql1)({ _ => })

// multiple percentages
val sql2 =
"select l_linenumber % 10, approx_percentile(l_extendedprice, array(0.1, 0.2, 0.3)) " +
"from lineitem group by l_linenumber % 10"
// Same no-op callback as above.
runQueryAndCompare(sql2)({ _ => })
}

test("GLUTEN-5096: Bug fix regexp_extract diff") {
val tbl_create_sql = "create table test_tbl_5096(id bigint, data string) using parquet"
val tbl_insert_sql = "insert into test_tbl_5096 values(1, 'abc'), (2, 'abc\n')"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFor

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, MakeYMInterval, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Count, Sum}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
Expand Down Expand Up @@ -335,8 +335,10 @@ object BackendSettings extends BackendSettingsApi {
case _ =>
}
windowExpression.windowFunction match {
case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank |
_: PercentRank | _: NthValue | _: NTile | _: Lag | _: Lead =>
case _: RowNumber | _: Rank | _: CumeDist | _: DenseRank | _: PercentRank |
_: NthValue | _: NTile | _: Lag | _: Lead =>
case aggrExpr: AggregateExpression
if !aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile] =>
taiyang-li marked this conversation as resolved.
Show resolved Hide resolved
case _ =>
allSupported = false
}
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace local_engine
{

/// Special case: goruping keys is empty, and there is no input from updstream, but still need to return one default row.
/// Special case: the grouping keys are empty and there is no input from upstream, but we still need to return one default row.
class DefaultHashAggregateResultStep : public DB::ITransformingStep
{
public:
Expand Down
14 changes: 3 additions & 11 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,11 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info,
const DB::ActionsDAG::Node * func_node,
DB::ActionsDAGPtr & actions_dag,
bool withNullability) const
bool with_nullability) const
{
const auto & output_type = func_info.output_type;
bool needToConvertNodeType = false;
if (withNullability)
{
needToConvertNodeType = !TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type);
}
else
{
needToConvertNodeType = !TypeParser::isTypeMatched(output_type, func_node->result_type);
}
if (needToConvertNodeType)
bool need_convert_type = !TypeParser::isTypeMatched(output_type, func_node->result_type, !with_nullability);
if (need_convert_type)
{
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
Expand Down
9 changes: 6 additions & 3 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class AggregateFunctionParser
/// In some special cases, different arguments size or different arguments types may refer to different
/// CH function implementation.
virtual String getCHFunctionName(const CommonFunctionInfo & func_info) const = 0;

/// In most cases, arguments size and types are enough to determine the CH function implementation.
/// This is only be used in SerializedPlanParser::parseNameStructure.
virtual String getCHFunctionName(const DB::DataTypes & args) const = 0;
/// It is only used in TypeParser::buildBlockFromNamedStruct.
/// Implementations are allowed to modify the argument types so that they fit AggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct.
virtual String getCHFunctionName(DB::DataTypes & args) const = 0;

/// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
virtual DB::ActionsDAG::NodeRawConstPtrs
Expand All @@ -112,7 +114,8 @@ class AggregateFunctionParser
virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info,
const DB::ActionsDAG::Node * func_node,
DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
DB::ActionsDAGPtr & actions_dag,
bool with_nullability) const;

/// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
Expand Down
42 changes: 33 additions & 9 deletions cpp-ch/local-engine/Parser/AggregateRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,22 @@ AggregateRelParser::AggregateRelParser(SerializedPlanParser * plan_paser_) : Rel
DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> &)
{
setup(std::move(query_plan), rel);

addPreProjection();
LOG_TRACE(logger, "header after pre-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
if (has_final_stage)
{
addMergingAggregatedStep();
LOG_TRACE(logger, "header after merging is: {}", plan->getCurrentDataStream().header.dumpStructure());

addPostProjection();
LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
}
else if (has_complete_stage)
{
addCompleteModeAggregatedStep();
LOG_TRACE(logger, "header after complete aggregate is: {}", plan->getCurrentDataStream().header.dumpStructure());

addPostProjection();
LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
}
Expand Down Expand Up @@ -184,6 +187,8 @@ void AggregateRelParser::addPreProjection()
}
if (projection_action->dumpDAG() != dag_footprint)
{
/// Avoid unnecessary evaluation
projection_action->removeUnusedActions();
auto projection_step = std::make_unique<DB::ExpressionStep>(plan->getCurrentDataStream(), projection_action);
projection_step->setStepDescription("Projection before aggregate");
steps.emplace_back(projection_step.get());
Expand All @@ -193,22 +198,41 @@ void AggregateRelParser::addPreProjection()

void AggregateRelParser::buildAggregateDescriptions(AggregateDescriptions & descriptions)
{
auto build_result_column_name = [](const String & function_name, const Strings & arg_column_names, substrait::AggregationPhase phase)
auto build_result_column_name = [](const String & function_name, const Array & params, const Strings & arg_names, substrait::AggregationPhase phase)
{
if (phase == substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT)
{
assert(arg_column_names.size() == 1);
return arg_column_names[0];
assert(arg_names.size() == 1);
return arg_names[0];
}

String result = function_name;
if (!params.empty())
{
result += "(";
for (size_t i = 0; i < params.size(); ++i)
{
if (i != 0)
result += ",";
result += toString(params[i]);
}
result += ")";
}
String arg_list_str = boost::algorithm::join(arg_column_names, ",");
return function_name + "(" + arg_list_str + ")";

result += "(";
result += boost::algorithm::join(arg_names, ",");
result += ")";
return result;
};

for (auto & agg_info : aggregates)
{
AggregateDescription description;
const auto & measure = agg_info.measure->measure();
description.column_name = build_result_column_name(agg_info.function_name, agg_info.arg_column_names, measure.phase());
description.column_name
= build_result_column_name(agg_info.function_name, agg_info.params, agg_info.arg_column_names, measure.phase());
agg_info.measure_column_name = description.column_name;
// std::cout << "description.column_name:" << description.column_name << std::endl;
description.argument_names = agg_info.arg_column_names;
DB::AggregateFunctionProperties properties;

Expand Down Expand Up @@ -259,7 +283,7 @@ void AggregateRelParser::addMergingAggregatedStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
Aggregator::Params params(
grouping_keys,
aggregate_descriptions,
Expand Down Expand Up @@ -298,7 +322,7 @@ void AggregateRelParser::addCompleteModeAggregatedStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);
if (enable_streaming_aggregating)
{
Expand Down Expand Up @@ -376,7 +400,7 @@ void AggregateRelParser::addAggregatingStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);

if (enable_streaming_aggregating)
Expand Down
16 changes: 11 additions & 5 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#include <Common/MergeTreeTool.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
#include <Common/JNIUtils.h>

namespace DB
{
Expand Down Expand Up @@ -312,12 +313,17 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait
auto iter = rel.local_files().items().at(0).uri_file();
auto pos = iter.find(':');
auto iter_index = std::stoi(iter.substr(pos + 1, iter.size()));
jobject input_iter = input_iters[iter_index];
bool materialize_input = materialize_inputs[iter_index];

GET_JNIENV(env)
SCOPE_EXIT({CLEAN_JNIENV});
auto * first_block = SourceFromJavaIter::peekBlock(env, input_iter);

/// Try to decide header from the first block read from Java iterator. Thus AggregateFunction with parameters has more precise types.
auto header = first_block ? first_block->cloneEmpty() : TypeParser::buildBlockFromNamedStruct(rel.base_schema());
taiyang-li marked this conversation as resolved.
Show resolved Hide resolved
auto source = std::make_shared<SourceFromJavaIter>(context, std::move(header), input_iter, materialize_input, first_block);

auto source = std::make_shared<SourceFromJavaIter>(
context,
TypeParser::buildBlockFromNamedStruct(rel.base_schema()),
input_iters[iter_index],
materialize_inputs[iter_index]);
QueryPlanStepPtr source_step = std::make_unique<ReadFromPreparedSource>(Pipe(source));
source_step->setStepDescription("Read From Java Iter");
return source_step;
Expand Down
27 changes: 14 additions & 13 deletions cpp-ch/local-engine/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ DB::Block TypeParser::buildBlockFromNamedStruct(
auto tmp_ctx = DB::Context::createCopy(SerializedPlanParser::global_context);
SerializedPlanParser tmp_plan_parser(tmp_ctx);
auto function_parser = AggregateFunctionParserFactory::instance().get(name_parts[3], &tmp_plan_parser);
/// This may remove elements from args_types, because some of them are used to determine CH function name, but not needed for the following
/// call `AggregateFunctionFactory::instance().get`
auto agg_function_name = function_parser->getCHFunctionName(args_types);
auto action = NullsAction::EMPTY;
ch_type = AggregateFunctionFactory::instance()
Expand Down Expand Up @@ -316,21 +318,20 @@ DB::Block TypeParser::buildBlockFromNamedStructWithoutDFS(const substrait::Named
return res;
}

bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type)
bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type, bool ignore_nullability)
{
const auto parsed_ch_type = TypeParser::parseType(substrait_type);
// if it's only different in nullability, we consider them same.
// this will be problematic for some functions being not-null in spark but nullable in clickhouse.
// e.g. murmur3hash
const auto a = removeNullable(parsed_ch_type);
const auto b = removeNullable(ch_type);
return a->equals(*b);
}

bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type)
{
const auto parsed_ch_type = TypeParser::parseType(substrait_type);
return parsed_ch_type->equals(*ch_type);
if (ignore_nullability)
{
// if it's only different in nullability, we consider them same.
// this will be problematic for some functions being not-null in spark but nullable in clickhouse.
// e.g. murmur3hash
const auto a = removeNullable(parsed_ch_type);
const auto b = removeNullable(ch_type);
return a->equals(*b);
}
else
return parsed_ch_type->equals(*ch_type);
}

DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type)
Expand Down
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Parser/TypeParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ namespace local_engine
/// Build block from substrait NamedStruct without DFS rules, different from buildBlockFromNamedStruct
static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct& struct_);

static bool isTypeMatched(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type);
static bool isTypeMatchedWithNullability(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type);
static bool isTypeMatched(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type, bool ignore_nullability = true);

private:
/// Mapping spark type names to CH type names.
Expand Down
7 changes: 4 additions & 3 deletions cpp-ch/local-engine/Parser/WindowRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ WindowRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & re
for (auto & it : window_descriptions)
{
auto & win = it.second;

auto window_step = std::make_unique<DB::WindowStep>(current_plan->getCurrentDataStream(), win, win.window_functions, false);
window_step->setStepDescription("Window step for window '" + win.window_name + "'");
steps.emplace_back(window_step.get());
Expand Down Expand Up @@ -328,13 +328,14 @@ void WindowRelParser::tryAddProjectionBeforeWindow()
for (auto & win_info : win_infos )
{
auto arg_nodes = win_info.function_parser->parseFunctionArguments(win_info.parser_func_info, actions_dag);
// This may remove elements from arg_nodes, because some of them are converted to CH func parameters.
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes);
for (auto & arg_node : arg_nodes)
{
win_info.arg_column_names.emplace_back(arg_node->result_name);
win_info.arg_column_types.emplace_back(arg_node->result_type);
actions_dag->addOrReplaceInOutputs(*arg_node);
}
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes);
}
}

if (actions_dag->dumpDAG() != dag_footprint)
Expand Down
Loading
Loading