From c0bad12820961c5f3650af4c247d0fe1f270c919 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Fri, 29 Mar 2024 14:11:26 +0800 Subject: [PATCH] [GLUTEN-2163][CH] support aggregate function approx_percentile (#4829) [CH] support aggregate function approx_percentile --- .../CHHashAggregateExecTransformer.scala | 8 + ...enClickHouseTPCHSaltNullParquetSuite.scala | 13 ++ .../backendsapi/velox/VeloxBackend.scala | 8 +- .../Operator/DefaultHashAggregateResult.h | 2 +- .../Parser/AggregateFunctionParser.cpp | 14 +- .../Parser/AggregateFunctionParser.h | 9 +- .../Parser/AggregateRelParser.cpp | 42 ++++-- .../Parser/SerializedPlanParser.cpp | 16 +- cpp-ch/local-engine/Parser/TypeParser.cpp | 27 ++-- cpp-ch/local-engine/Parser/TypeParser.h | 3 +- .../local-engine/Parser/WindowRelParser.cpp | 7 +- .../ApproxPercentileParser.cpp | 140 ++++++++++++++++++ .../ApproxPercentileParser.h | 48 ++++++ .../BloomFilterAggParser.h | 4 +- .../CollectListParser.h | 8 +- .../CommonAggregateFunctionParser.h | 2 +- .../aggregate_function_parser/CountParser.cpp | 7 +- .../aggregate_function_parser/CountParser.h | 2 +- .../aggregate_function_parser/LeadLagParser.h | 4 +- .../Storages/SourceFromJavaIter.cpp | 78 ++++++---- .../Storages/SourceFromJavaIter.h | 11 +- .../SubstraitSource/ReadBufferBuilder.cpp | 27 ++-- .../expression/ExpressionMappings.scala | 3 +- .../clickhouse/ClickHouseTestSettings.scala | 1 + .../clickhouse/ClickHouseTestSettings.scala | 1 + .../clickhouse/ClickHouseTestSettings.scala | 1 + .../expression/ExpressionNames.scala | 1 + 27 files changed, 381 insertions(+), 106 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp create mode 100644 cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala 
b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala index 13d02ffa7555..8a312c456851 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala @@ -365,6 +365,14 @@ case class CHHashAggregateExecTransformer( fields = fields :+ (child.dataType, child.nullable) } (makeStructType(fields), false) + case approxPercentile: ApproximatePercentile => + var fields = Seq[(DataType, Boolean)]() + // Use approxPercentile.nullable as the nullable of the struct type + // to make sure it returns null when input is empty + fields = fields :+ (approxPercentile.child.dataType, approxPercentile.nullable) + fields = fields :+ (approxPercentile.percentageExpression.dataType, + approxPercentile.percentageExpression.nullable) + (makeStructType(fields), attr.nullable) case _ => (makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable) } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 5c9d31a01d20..098b0d9e2f8a 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2465,6 +2465,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_4997") } + test("aggregate function approx_percentile") { + // single percentage + val sql1 = "select l_linenumber % 10, approx_percentile(l_extendedprice, 0.5) " + + "from lineitem group by l_linenumber % 10" + runQueryAndCompare(sql1)({ _ => }) + + // multiple percentages + val sql2 = + "select l_linenumber % 10, 
approx_percentile(l_extendedprice, array(0.1, 0.2, 0.3)) " + + "from lineitem group by l_linenumber % 10" + runQueryAndCompare(sql2)({ _ => }) + } + test("GLUTEN-5096: Bug fix regexp_extract diff") { val tbl_create_sql = "create table test_tbl_5096(id bigint, data string) using parquet" val tbl_insert_sql = "insert into test_tbl_5096 values(1, 'abc'), (2, 'abc\n')" diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala index 25d389f3cb95..e27b8084e0bd 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala @@ -28,7 +28,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFor import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, MakeYMInterval, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Count, Sum} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} @@ -335,8 +335,10 @@ object BackendSettings extends BackendSettingsApi { case _ => } windowExpression.windowFunction match { - case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank | - _: PercentRank | _: NthValue | _: NTile | _: Lag | _: Lead => + case _: RowNumber | _: Rank | _: CumeDist | _: DenseRank | _: PercentRank | + _: NthValue | _: NTile | _: Lag | _: Lead => + case 
aggrExpr: AggregateExpression + if !aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile] => case _ => allSupported = false } diff --git a/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h b/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h index 5729a4770789..b433f4072992 100644 --- a/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h +++ b/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h @@ -22,7 +22,7 @@ namespace local_engine { -/// Special case: goruping keys is empty, and there is no input from updstream, but still need to return one default row. +/// Special case: grouping keys are empty, and there is no input from upstream, but still need to return one default row. class DefaultHashAggregateResultStep : public DB::ITransformingStep { public: diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp index 708a169145a3..56cd58ad9eab 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp @@ -148,19 +148,11 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded( const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, - bool withNullability) const + bool with_nullability) const { const auto & output_type = func_info.output_type; - bool needToConvertNodeType = false; - if (withNullability) - { - needToConvertNodeType = !TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type); - } - else - { - needToConvertNodeType = !TypeParser::isTypeMatched(output_type, func_node->result_type); - } - if (needToConvertNodeType) + bool need_convert_type = !TypeParser::isTypeMatched(output_type, func_node->result_type, !with_nullability); + if (need_convert_type) { func_node = ActionsDAGUtil::convertNodeType( actions_dag, func_node, TypeParser::parseType(output_type)->getName(), 
func_node->result_name); diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h index e2444361f310..464ad099a3b6 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h @@ -89,9 +89,11 @@ class AggregateFunctionParser /// In some special cases, different arguments size or different arguments types may refer to different /// CH function implementation. virtual String getCHFunctionName(const CommonFunctionInfo & func_info) const = 0; + /// In most cases, arguments size and types are enough to determine the CH function implementation. - /// This is only be used in SerializedPlanParser::parseNameStructure. - virtual String getCHFunctionName(const DB::DataTypes & args) const = 0; + /// It is only used in TypeParser::buildBlockFromNamedStruct + /// Users are allowed to modify arg types to make it fit for AggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct + virtual String getCHFunctionName(DB::DataTypes & args) const = 0; /// Do some preprojections for the function arguments, and return the necessary arguments for the CH function. virtual DB::ActionsDAG::NodeRawConstPtrs @@ -112,7 +114,8 @@ class AggregateFunctionParser virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, - DB::ActionsDAGPtr & actions_dag, bool withNullability) const; + DB::ActionsDAGPtr & actions_dag, + bool with_nullability) const; /// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x). /// 0.5 is the parameter of percentiles function. 
diff --git a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp index a3ab329f0af1..d20f30e41191 100644 --- a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp @@ -54,12 +54,14 @@ AggregateRelParser::AggregateRelParser(SerializedPlanParser * plan_paser_) : Rel DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list &) { setup(std::move(query_plan), rel); + addPreProjection(); LOG_TRACE(logger, "header after pre-projection is: {}", plan->getCurrentDataStream().header.dumpStructure()); if (has_final_stage) { addMergingAggregatedStep(); LOG_TRACE(logger, "header after merging is: {}", plan->getCurrentDataStream().header.dumpStructure()); + addPostProjection(); LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure()); } @@ -67,6 +69,7 @@ DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const su { addCompleteModeAggregatedStep(); LOG_TRACE(logger, "header after complete aggregate is: {}", plan->getCurrentDataStream().header.dumpStructure()); + addPostProjection(); LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure()); } @@ -184,6 +187,8 @@ void AggregateRelParser::addPreProjection() } if (projection_action->dumpDAG() != dag_footprint) { + /// Avoid unnecessary evaluation + projection_action->removeUnusedActions(); auto projection_step = std::make_unique(plan->getCurrentDataStream(), projection_action); projection_step->setStepDescription("Projection before aggregate"); steps.emplace_back(projection_step.get()); @@ -193,22 +198,41 @@ void AggregateRelParser::addPreProjection() void AggregateRelParser::buildAggregateDescriptions(AggregateDescriptions & descriptions) { - auto build_result_column_name = [](const String & function_name, const Strings & arg_column_names, 
substrait::AggregationPhase phase) + auto build_result_column_name = [](const String & function_name, const Array & params, const Strings & arg_names, substrait::AggregationPhase phase) { if (phase == substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT) { - assert(arg_column_names.size() == 1); - return arg_column_names[0]; + assert(arg_names.size() == 1); + return arg_names[0]; + } + + String result = function_name; + if (!params.empty()) + { + result += "("; + for (size_t i = 0; i < params.size(); ++i) + { + if (i != 0) + result += ","; + result += toString(params[i]); + } + result += ")"; } - String arg_list_str = boost::algorithm::join(arg_column_names, ","); - return function_name + "(" + arg_list_str + ")"; + + result += "("; + result += boost::algorithm::join(arg_names, ","); + result += ")"; + return result; }; + for (auto & agg_info : aggregates) { AggregateDescription description; const auto & measure = agg_info.measure->measure(); - description.column_name = build_result_column_name(agg_info.function_name, agg_info.arg_column_names, measure.phase()); + description.column_name + = build_result_column_name(agg_info.function_name, agg_info.params, agg_info.arg_column_names, measure.phase()); agg_info.measure_column_name = description.column_name; + // std::cout << "description.column_name:" << description.column_name << std::endl; description.argument_names = agg_info.arg_column_names; DB::AggregateFunctionProperties properties; @@ -259,7 +283,7 @@ void AggregateRelParser::addMergingAggregatedStep() { AggregateDescriptions aggregate_descriptions; buildAggregateDescriptions(aggregate_descriptions); - auto settings = getContext()->getSettingsRef(); + const auto & settings = getContext()->getSettingsRef(); Aggregator::Params params( grouping_keys, aggregate_descriptions, @@ -298,7 +322,7 @@ void AggregateRelParser::addCompleteModeAggregatedStep() { AggregateDescriptions aggregate_descriptions; 
buildAggregateDescriptions(aggregate_descriptions); - auto settings = getContext()->getSettingsRef(); + const auto & settings = getContext()->getSettingsRef(); bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true); if (enable_streaming_aggregating) { @@ -376,7 +400,7 @@ void AggregateRelParser::addAggregatingStep() { AggregateDescriptions aggregate_descriptions; buildAggregateDescriptions(aggregate_descriptions); - auto settings = getContext()->getSettingsRef(); + const auto & settings = getContext()->getSettingsRef(); bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true); if (enable_streaming_aggregating) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index f0715f5009a4..1b7f0448f1e0 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -89,6 +89,7 @@ #include #include #include +#include namespace DB { @@ -312,12 +313,17 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait auto iter = rel.local_files().items().at(0).uri_file(); auto pos = iter.find(':'); auto iter_index = std::stoi(iter.substr(pos + 1, iter.size())); + jobject input_iter = input_iters[iter_index]; + bool materialize_input = materialize_inputs[iter_index]; + + GET_JNIENV(env) + SCOPE_EXIT({CLEAN_JNIENV}); + auto * first_block = SourceFromJavaIter::peekBlock(env, input_iter); + + /// Try to decide header from the first block read from Java iterator. Thus AggregateFunction with parameters has more precise types. + auto header = first_block ? 
first_block->cloneEmpty() : TypeParser::buildBlockFromNamedStruct(rel.base_schema()); + auto source = std::make_shared(context, std::move(header), input_iter, materialize_input, first_block); - auto source = std::make_shared( - context, - TypeParser::buildBlockFromNamedStruct(rel.base_schema()), - input_iters[iter_index], - materialize_inputs[iter_index]); QueryPlanStepPtr source_step = std::make_unique(Pipe(source)); source_step->setStepDescription("Read From Java Iter"); return source_step; diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp b/cpp-ch/local-engine/Parser/TypeParser.cpp index 12c23e6060bd..3ad19bb2bd73 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.cpp +++ b/cpp-ch/local-engine/Parser/TypeParser.cpp @@ -280,6 +280,8 @@ DB::Block TypeParser::buildBlockFromNamedStruct( auto tmp_ctx = DB::Context::createCopy(SerializedPlanParser::global_context); SerializedPlanParser tmp_plan_parser(tmp_ctx); auto function_parser = AggregateFunctionParserFactory::instance().get(name_parts[3], &tmp_plan_parser); + /// This may remove elements from args_types, because some of them are used to determine CH function name, but not needed for the following + /// call `AggregateFunctionFactory::instance().get` auto agg_function_name = function_parser->getCHFunctionName(args_types); auto action = NullsAction::EMPTY; ch_type = AggregateFunctionFactory::instance() @@ -316,21 +318,20 @@ DB::Block TypeParser::buildBlockFromNamedStructWithoutDFS(const substrait::Named return res; } -bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type) +bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type, bool ignore_nullability) { const auto parsed_ch_type = TypeParser::parseType(substrait_type); - // if it's only different in nullability, we consider them same. - // this will be problematic for some functions being not-null in spark but nullable in clickhouse. - // e.g. 
murmur3hash - const auto a = removeNullable(parsed_ch_type); - const auto b = removeNullable(ch_type); - return a->equals(*b); -} - -bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type) -{ - const auto parsed_ch_type = TypeParser::parseType(substrait_type); - return parsed_ch_type->equals(*ch_type); + if (ignore_nullability) + { + // if it's only different in nullability, we consider them same. + // this will be problematic for some functions being not-null in spark but nullable in clickhouse. + // e.g. murmur3hash + const auto a = removeNullable(parsed_ch_type); + const auto b = removeNullable(ch_type); + return a->equals(*b); + } + else + return parsed_ch_type->equals(*ch_type); } DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type) diff --git a/cpp-ch/local-engine/Parser/TypeParser.h b/cpp-ch/local-engine/Parser/TypeParser.h index c687c3024aa8..666effff45d3 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.h +++ b/cpp-ch/local-engine/Parser/TypeParser.h @@ -49,8 +49,7 @@ namespace local_engine /// Build block from substrait NamedStruct without DFS rules, different from buildBlockFromNamedStruct static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct& struct_); - static bool isTypeMatched(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type); - static bool isTypeMatchedWithNullability(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type); + static bool isTypeMatched(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type, bool ignore_nullability = true); private: /// Mapping spark type names to CH type names. 
diff --git a/cpp-ch/local-engine/Parser/WindowRelParser.cpp b/cpp-ch/local-engine/Parser/WindowRelParser.cpp index a1787a2c93c5..4125879b5ec7 100644 --- a/cpp-ch/local-engine/Parser/WindowRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowRelParser.cpp @@ -83,7 +83,7 @@ WindowRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & re for (auto & it : window_descriptions) { auto & win = it.second; - + auto window_step = std::make_unique(current_plan->getCurrentDataStream(), win, win.window_functions, false); window_step->setStepDescription("Window step for window '" + win.window_name + "'"); steps.emplace_back(window_step.get()); @@ -328,13 +328,14 @@ void WindowRelParser::tryAddProjectionBeforeWindow() for (auto & win_info : win_infos ) { auto arg_nodes = win_info.function_parser->parseFunctionArguments(win_info.parser_func_info, actions_dag); + // This may remove elements from arg_nodes, because some of them are converted to CH func parameters. + win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes); for (auto & arg_node : arg_nodes) { win_info.arg_column_names.emplace_back(arg_node->result_name); win_info.arg_column_types.emplace_back(arg_node->result_type); actions_dag->addOrReplaceInOutputs(*arg_node); - } - win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes); + } } if (actions_dag->dumpDAG() != dag_footprint) diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp new file mode 100644 index 000000000000..9a164de14509 --- /dev/null +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "substrait/algebra.pb.h" + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +void ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const +{ + if (size != expect) + throw Exception( + DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Function {} in phase {} requires exactly {} arguments but got {} arguments", + getName(), + magic_enum::enum_name(phase), + expect, + size); +} + +const substrait::Expression::Literal & +ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const +{ + if (!expr.has_literal()) + throw Exception( + DB::ErrorCodes::BAD_ARGUMENTS, + "The argument of function {} in phase {} must be literal, but is {}", + getName(), + magic_enum::enum_name(phase), + expr.DebugString()); + return expr.literal(); +} + +String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo & func_info) const +{ + const auto & output_type = func_info.output_type; + return output_type.has_list() ? 
"quantilesGK" : "quantileGK"; +} + +String ApproxPercentileParser::getCHFunctionName(DB::DataTypes & types) const +{ + /// Always invoked during second stage + assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), 2); + + auto type = removeNullable(types[1]); + types.resize(1); + return isArray(type) ? "quantilesGK" : "quantileGK"; +} + +DB::Array ApproxPercentileParser::parseFunctionParameters( + const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const +{ + if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE + || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT || func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED) + { + Array params; + const auto & arguments = func_info.arguments; + assertArgumentsSize(func_info.phase, arguments.size(), 3); + + const auto & accuracy_expr = arguments[2].value(); + const auto & accuracy_literal = assertAndGetLiteral(func_info.phase, accuracy_expr); + auto [type1, field1] = parseLiteral(accuracy_literal); + params.emplace_back(std::move(field1)); + + const auto & percentage_expr = arguments[1].value(); + const auto & percentage_literal = assertAndGetLiteral(func_info.phase, percentage_expr); + auto [type2, field2] = parseLiteral(percentage_literal); + if (isArray(type2)) + { + /// Multiple percentages for quantilesGK + const Array & percentags = field2.get(); + for (const auto & percentage : percentags) + params.emplace_back(percentage); + } + else + { + /// Single percentage for quantileGK + params.emplace_back(std::move(field2)); + } + + /// Delete percentage and accuracy argument for clickhouse compatiability + arg_nodes.resize(1); + return params; + } + else + { + assertArgumentsSize(func_info.phase, arg_nodes.size(), 1); + const auto & result_type = arg_nodes[0]->result_type; + const auto * aggregate_function_type = DB::checkAndGetDataType(result_type.get()); + if (!aggregate_function_type) + throw Exception( + 
DB::ErrorCodes::BAD_ARGUMENTS, + "The first argument type of function {} in phase {} must be AggregateFunction, but is {}", + getName(), + magic_enum::enum_name(func_info.phase), + result_type->getName()); + + return aggregate_function_type->getParameters(); + } +} + +DB::Array ApproxPercentileParser::getDefaultFunctionParameters() const +{ + return {10000, 1}; +} + + +static const AggregateFunctionParserRegister register_approx_percentile; +} diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h new file mode 100644 index 000000000000..37eae30457b4 --- /dev/null +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + + +/* +spark: approx_percentile(col, percentage [, accuracy]) +1. When percentage is an array literal, spark returns an array of percentiles, corresponding to CH: quantilesGK(accuracy, percentage[0], ...)(col) +2. 
Otherwise spark returns a single percentile, corresponding to CH: quantileGK(accuracy, percentage)(col) +*/ + +namespace local_engine +{ +class ApproxPercentileParser : public AggregateFunctionParser +{ +public: + explicit ApproxPercentileParser(SerializedPlanParser * plan_parser_) : AggregateFunctionParser(plan_parser_) { } + ~ApproxPercentileParser() override = default; + String getName() const override { return name; } + static constexpr auto name = "approx_percentile"; + String getCHFunctionName(const CommonFunctionInfo & func_info) const override; + String getCHFunctionName(DB::DataTypes & types) const override; + + DB::Array + parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override; + + DB::Array getDefaultFunctionParameters() const override; + +private: + void assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const; + const substrait::Expression::Literal & assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const; +}; +} diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h index 4bb41ca6d4c2..2465164421f1 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h @@ -26,10 +26,10 @@ class AggregateFunctionParserBloomFilterAgg : public AggregateFunctionParser public: explicit AggregateFunctionParserBloomFilterAgg(SerializedPlanParser * plan_parser_) : AggregateFunctionParser(plan_parser_) { } ~AggregateFunctionParserBloomFilterAgg() override = default; - String getName() const override { return "bloom_filter_agg"; } + String getName() const override { return name; } static constexpr auto name = "bloom_filter_agg"; String getCHFunctionName(const CommonFunctionInfo &) const override { return 
"groupBloomFilterState"; } - String getCHFunctionName(const DB ::DataTypes &) const override { return "groupBloomFilterState"; } + String getCHFunctionName(DB::DataTypes &) const override { return "groupBloomFilterState"; } DB::Array parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override; diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h index d7a9c1a5c188..60e1b4eaedd3 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h @@ -47,12 +47,12 @@ class CollectFunctionParser : public AggregateFunctionParser throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement"); } - virtual String getCHFunctionName(const DB::DataTypes &) const override + virtual String getCHFunctionName(DB::DataTypes &) const override { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement"); } const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool /* withNullability */) const override + const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool /* with_nullability */) const override { const DB::ActionsDAG::Node * ret_node = func_node; if (func_node->result_type->isNullable()) @@ -79,7 +79,7 @@ class CollectListParser : public CollectFunctionParser static constexpr auto name = "collect_list"; String getName() const override { return name; } String getCHFunctionName(const CommonFunctionInfo &) const override { return "groupArray"; } - String getCHFunctionName(const DB::DataTypes &) const override { return "groupArray"; } + String getCHFunctionName(DB::DataTypes &) const override { return "groupArray"; } }; class CollectSetParser : 
public CollectFunctionParser @@ -90,6 +90,6 @@ class CollectSetParser : public CollectFunctionParser static constexpr auto name = "collect_set"; String getName() const override { return name; } String getCHFunctionName(const CommonFunctionInfo &) const override { return "groupUniqArray"; } - String getCHFunctionName(const DB::DataTypes &) const override { return "groupUniqArray"; } + String getCHFunctionName(DB::DataTypes &) const override { return "groupUniqArray"; } }; } diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.h index c486e16e1c69..e21581e00e3a 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.h @@ -31,7 +31,7 @@ namespace local_engine String getName() const override { return #substait_name; } \ static constexpr auto name = #substait_name; \ String getCHFunctionName(const CommonFunctionInfo &) const override { return #ch_name; } \ - String getCHFunctionName(const DB::DataTypes &) const override { return #ch_name; } \ + String getCHFunctionName(DB::DataTypes &) const override { return #ch_name; } \ }; \ static const AggregateFunctionParserRegister register_##cls_name = AggregateFunctionParserRegister(); diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp index 47c9c38f7f9f..6135546f2e0f 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.cpp @@ -34,10 +34,12 @@ String CountParser::getCHFunctionName(const CommonFunctionInfo &) const { return "count"; } -String CountParser::getCHFunctionName(const DB::DataTypes &) const + +String CountParser::getCHFunctionName(DB::DataTypes &) const { return 
"count"; } + DB::ActionsDAG::NodeRawConstPtrs CountParser::parseFunctionArguments( const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const { @@ -45,7 +47,7 @@ DB::ActionsDAG::NodeRawConstPtrs CountParser::parseFunctionArguments( { throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} requires at least one argument", getName()); } - + const DB::ActionsDAG::Node * last_arg_node = nullptr; if (func_info.arguments.size() == 1) { @@ -82,5 +84,6 @@ DB::ActionsDAG::NodeRawConstPtrs CountParser::parseFunctionArguments( } return {last_arg_node}; } + static const AggregateFunctionParserRegister register_count; } diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h index 7f34112137a6..a561f87d940d 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CountParser.h @@ -28,7 +28,7 @@ class CountParser : public AggregateFunctionParser static constexpr auto name = "count"; String getName() const override { return name; } String getCHFunctionName(const CommonFunctionInfo &) const override; - String getCHFunctionName(const DB::DataTypes &) const override; + String getCHFunctionName(DB::DataTypes &) const override; DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; }; diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h index 35afbfc47587..4fa1c1bbca13 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.h @@ -27,7 +27,7 @@ class LeadParser : public AggregateFunctionParser static constexpr auto name = "lead"; String 
getName() const override { return name; } String getCHFunctionName(const CommonFunctionInfo &) const override { return "leadInFrame"; } - String getCHFunctionName(const DB::DataTypes &) const override { return "leadInFrame"; } + String getCHFunctionName(DB::DataTypes &) const override { return "leadInFrame"; } DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; }; @@ -40,7 +40,7 @@ class LagParser : public AggregateFunctionParser static constexpr auto name = "lag"; String getName() const override { return name; } String getCHFunctionName(const CommonFunctionInfo &) const override { return "lagInFrame"; } - String getCHFunctionName(const DB::DataTypes &) const override { return "lagInFrame"; } + String getCHFunctionName(DB::DataTypes &) const override { return "lagInFrame"; } DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override; }; diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index d67161623af2..54d1d253e539 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -49,47 +49,73 @@ static DB::Block getRealHeader(const DB::Block & header) return BlockUtil::buildRowCountHeader(); } -SourceFromJavaIter::SourceFromJavaIter(DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool materialize_input_) + +DB::Block * SourceFromJavaIter::peekBlock(JNIEnv * env, jobject java_iter) +{ + jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext); + if (has_next) + { + jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); + return reinterpret_cast(byteArrayToLong(env, block)); + } + return nullptr; 
+} + + +SourceFromJavaIter::SourceFromJavaIter( + DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool materialize_input_, DB::Block * first_block_) : DB::ISource(getRealHeader(header)) + , context(context_) + , original_header(header) , java_iter(java_iter_) , materialize_input(materialize_input_) - , original_header(header) - , context(context_) + , first_block(first_block_) { } DB::Chunk SourceFromJavaIter::generate() { GET_JNIENV(env) - jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext); + SCOPE_EXIT({CLEAN_JNIENV}); + DB::Chunk result; - if (has_next) + DB::Block * data = nullptr; + if (first_block) [[unlikely]] + { + data = first_block; + first_block = nullptr; + } + else if (jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext)) { jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); - DB::Block * data = reinterpret_cast(byteArrayToLong(env, block)); - if (materialize_input) - materializeBlockInplace(*data); - if (data->rows() > 0) + data = reinterpret_cast(byteArrayToLong(env, block)); + } + else + return {}; + + /// Post-processing + if (materialize_input) + materializeBlockInplace(*data); + + if (data->rows() > 0) + { + size_t rows = data->rows(); + if (original_header.columns()) { - size_t rows = data->rows(); - if (original_header.columns()) - { - result.setColumns(data->mutateColumns(), rows); - convertNullable(result); - auto info = std::make_shared(); - info->is_overflows = data->info.is_overflows; - info->bucket_num = data->info.bucket_num; - result.setChunkInfo(info); - } - else - { - result = BlockUtil::buildRowCountChunk(rows); - auto info = std::make_shared(); - result.setChunkInfo(info); - } + result.setColumns(data->mutateColumns(), rows); + convertNullable(result); + auto info = std::make_shared(); + info->is_overflows = data->info.is_overflows; + info->bucket_num = 
data->info.bucket_num; + result.setChunkInfo(info); + } + else + { + result = BlockUtil::buildRowCountChunk(rows); + auto info = std::make_shared(); + result.setChunkInfo(info); } } - CLEAN_JNIENV return result; } diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h index e5cc601cff1e..6ee02e7480a0 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h @@ -31,7 +31,9 @@ class SourceFromJavaIter : public DB::ISource static Int64 byteArrayToLong(JNIEnv * env, jbyteArray arr); - SourceFromJavaIter(DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool materialize_input_); + static DB::Block * peekBlock(JNIEnv * env, jobject java_iter); + + SourceFromJavaIter(DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool materialize_input_, DB::Block * peek_block_); ~SourceFromJavaIter() override; String getName() const override { return "SourceFromJavaIter"; } @@ -41,10 +43,13 @@ class SourceFromJavaIter : public DB::ISource void convertNullable(DB::Chunk & chunk); DB::ColumnPtr convertNestedNullable(const DB::ColumnPtr & column, const DB::DataTypePtr & target_type); - jobject java_iter; - bool materialize_input; DB::ContextPtr context; DB::Block original_header; + jobject java_iter; + bool materialize_input; + + /// The first block read from java iteration to decide exact types of columns, especially for AggregateFunctions with parameters. 
+ DB::Block * first_block = nullptr; }; } diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp index 4f60fc54daa1..21640fe490c7 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp @@ -14,7 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include "config.h" + #include +#include +#include #include #include #include @@ -22,9 +27,13 @@ #include #include #include +#include #include #include #include +#include +#include +#include #include #include #include @@ -32,22 +41,17 @@ #include #include #include - -#include -#include -#include "IO/ReadSettings.h" - +#include #include +#include #include +#include +#include #include #include #include #include -#include -#include -#include - #if USE_AWS_S3 #include #include @@ -55,11 +59,6 @@ #include #endif -#include - -#include -#include -#include namespace DB { diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala index 42619b407fe6..e9133b6f228f 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala @@ -276,7 +276,8 @@ object ExpressionMappings { Sig[CovSample](COVAR_SAMP), Sig[Last](LAST), Sig[First](FIRST), - Sig[Skewness](SKEWNESS) + Sig[Skewness](SKEWNESS), + Sig[ApproximatePercentile](APPROX_PERCENTILE) ) ++ SparkShimLoader.getSparkShims.aggregateExpressionMappings /** Mapping Spark window expression to Substrait function name */ diff --git a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala 
b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala index c9aaaee2852f..bd2cb26abf56 100644 --- a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala @@ -230,6 +230,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-22520: support code generation for large CaseWhen") .exclude("SPARK-24165: CaseWhen/If - nullability of nested types") .exclude("SPARK-27671: Fix analysis exception when casting null in nested field in struct") + .exclude("summary") .excludeGlutenTest("distributeBy and localSort") .excludeGlutenTest("describe") .excludeGlutenTest("Allow leading/trailing whitespace in string before casting") diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala index d000dbecb367..0194518959b4 100644 --- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala @@ -252,6 +252,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-22520: support code generation for large CaseWhen") .exclude("SPARK-24165: CaseWhen/If - nullability of nested types") .exclude("SPARK-27671: Fix analysis exception when casting null in nested field in struct") + .exclude("summary") .excludeGlutenTest("distributeBy and localSort") .excludeGlutenTest("describe") .excludeGlutenTest("Allow leading/trailing whitespace in string before casting") diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala index 
9fabee3ca705..63dd2f5476fb 100644 --- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala @@ -250,6 +250,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-22520: support code generation for large CaseWhen") .exclude("SPARK-24165: CaseWhen/If - nullability of nested types") .exclude("SPARK-27671: Fix analysis exception when casting null in nested field in struct") + .exclude("summary") .excludeGlutenTest("distributeBy and localSort") .excludeGlutenTest("describe") .excludeGlutenTest("Allow leading/trailing whitespace in string before casting") diff --git a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala index 56e10741a20c..f61aa3161662 100644 --- a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala @@ -46,6 +46,7 @@ object ExpressionNames { final val FIRST_IGNORE_NULL = "first_ignore_null" final val APPROX_DISTINCT = "approx_distinct" final val SKEWNESS = "skewness" + final val APPROX_PERCENTILE = "approx_percentile" // Function names used by Substrait plan. final val ADD = "add"