Skip to content

Commit

Permalink
[GLUTEN-2163][CH] support aggregate function approx_percentile (#4829)
Browse files Browse the repository at this point in the history
[CH] support aggregate function approx_percentile
  • Loading branch information
taiyang-li authored Mar 29, 2024
1 parent 2eff2dd commit c0bad12
Show file tree
Hide file tree
Showing 27 changed files with 381 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ case class CHHashAggregateExecTransformer(
fields = fields :+ (child.dataType, child.nullable)
}
(makeStructType(fields), false)
case approxPercentile: ApproximatePercentile =>
var fields = Seq[(DataType, Boolean)]()
// Use approxPercentile.nullable as the nullable of the struct type
// to make sure it returns null when input is empty
fields = fields :+ (approxPercentile.child.dataType, approxPercentile.nullable)
fields = fields :+ (approxPercentile.percentageExpression.dataType,
approxPercentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
spark.sql("drop table test_tbl_4997")
}

test("aggregate function approx_percentile") {
// single percentage
val sql1 = "select l_linenumber % 10, approx_percentile(l_extendedprice, 0.5) " +
"from lineitem group by l_linenumber % 10"
runQueryAndCompare(sql1)({ _ => })

// multiple percentages
val sql2 =
"select l_linenumber % 10, approx_percentile(l_extendedprice, array(0.1, 0.2, 0.3)) " +
"from lineitem group by l_linenumber % 10"
runQueryAndCompare(sql2)({ _ => })
}

test("GLUTEN-5096: Bug fix regexp_extract diff") {
val tbl_create_sql = "create table test_tbl_5096(id bigint, data string) using parquet"
val tbl_insert_sql = "insert into test_tbl_5096 values(1, 'abc'), (2, 'abc\n')"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFor

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, MakeYMInterval, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile, Count, Sum}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
Expand Down Expand Up @@ -335,8 +335,10 @@ object BackendSettings extends BackendSettingsApi {
case _ =>
}
windowExpression.windowFunction match {
case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank |
_: PercentRank | _: NthValue | _: NTile | _: Lag | _: Lead =>
case _: RowNumber | _: Rank | _: CumeDist | _: DenseRank | _: PercentRank |
_: NthValue | _: NTile | _: Lag | _: Lead =>
case aggrExpr: AggregateExpression
if !aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile] =>
case _ =>
allSupported = false
}
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Operator/DefaultHashAggregateResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace local_engine
{

/// Special case: goruping keys is empty, and there is no input from updstream, but still need to return one default row.
/// Special case: goruping keys is empty, and there is no input from upstream, but still need to return one default row.
class DefaultHashAggregateResultStep : public DB::ITransformingStep
{
public:
Expand Down
14 changes: 3 additions & 11 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,11 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info,
const DB::ActionsDAG::Node * func_node,
DB::ActionsDAGPtr & actions_dag,
bool withNullability) const
bool with_nullability) const
{
const auto & output_type = func_info.output_type;
bool needToConvertNodeType = false;
if (withNullability)
{
needToConvertNodeType = !TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type);
}
else
{
needToConvertNodeType = !TypeParser::isTypeMatched(output_type, func_node->result_type);
}
if (needToConvertNodeType)
bool need_convert_type = !TypeParser::isTypeMatched(output_type, func_node->result_type, !with_nullability);
if (need_convert_type)
{
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
Expand Down
9 changes: 6 additions & 3 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class AggregateFunctionParser
/// In some special cases, different arguments size or different arguments types may refer to different
/// CH function implementation.
virtual String getCHFunctionName(const CommonFunctionInfo & func_info) const = 0;

/// In most cases, arguments size and types are enough to determine the CH function implementation.
/// This is only be used in SerializedPlanParser::parseNameStructure.
virtual String getCHFunctionName(const DB::DataTypes & args) const = 0;
/// It is only be used in TypeParser::buildBlockFromNamedStruct
/// Users are allowed to modify arg types to make it fit for ggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct
virtual String getCHFunctionName(DB::DataTypes & args) const = 0;

/// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
virtual DB::ActionsDAG::NodeRawConstPtrs
Expand All @@ -112,7 +114,8 @@ class AggregateFunctionParser
virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const CommonFunctionInfo & func_info,
const DB::ActionsDAG::Node * func_node,
DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
DB::ActionsDAGPtr & actions_dag,
bool with_nullability) const;

/// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
Expand Down
42 changes: 33 additions & 9 deletions cpp-ch/local-engine/Parser/AggregateRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,22 @@ AggregateRelParser::AggregateRelParser(SerializedPlanParser * plan_paser_) : Rel
DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> &)
{
setup(std::move(query_plan), rel);

addPreProjection();
LOG_TRACE(logger, "header after pre-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
if (has_final_stage)
{
addMergingAggregatedStep();
LOG_TRACE(logger, "header after merging is: {}", plan->getCurrentDataStream().header.dumpStructure());

addPostProjection();
LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
}
else if (has_complete_stage)
{
addCompleteModeAggregatedStep();
LOG_TRACE(logger, "header after complete aggregate is: {}", plan->getCurrentDataStream().header.dumpStructure());

addPostProjection();
LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
}
Expand Down Expand Up @@ -184,6 +187,8 @@ void AggregateRelParser::addPreProjection()
}
if (projection_action->dumpDAG() != dag_footprint)
{
/// Avoid unnecessary evaluation
projection_action->removeUnusedActions();
auto projection_step = std::make_unique<DB::ExpressionStep>(plan->getCurrentDataStream(), projection_action);
projection_step->setStepDescription("Projection before aggregate");
steps.emplace_back(projection_step.get());
Expand All @@ -193,22 +198,41 @@ void AggregateRelParser::addPreProjection()

void AggregateRelParser::buildAggregateDescriptions(AggregateDescriptions & descriptions)
{
auto build_result_column_name = [](const String & function_name, const Strings & arg_column_names, substrait::AggregationPhase phase)
auto build_result_column_name = [](const String & function_name, const Array & params, const Strings & arg_names, substrait::AggregationPhase phase)
{
if (phase == substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT)
{
assert(arg_column_names.size() == 1);
return arg_column_names[0];
assert(arg_names.size() == 1);
return arg_names[0];
}

String result = function_name;
if (!params.empty())
{
result += "(";
for (size_t i = 0; i < params.size(); ++i)
{
if (i != 0)
result += ",";
result += toString(params[i]);
}
result += ")";
}
String arg_list_str = boost::algorithm::join(arg_column_names, ",");
return function_name + "(" + arg_list_str + ")";

result += "(";
result += boost::algorithm::join(arg_names, ",");
result += ")";
return result;
};

for (auto & agg_info : aggregates)
{
AggregateDescription description;
const auto & measure = agg_info.measure->measure();
description.column_name = build_result_column_name(agg_info.function_name, agg_info.arg_column_names, measure.phase());
description.column_name
= build_result_column_name(agg_info.function_name, agg_info.params, agg_info.arg_column_names, measure.phase());
agg_info.measure_column_name = description.column_name;
// std::cout << "description.column_name:" << description.column_name << std::endl;
description.argument_names = agg_info.arg_column_names;
DB::AggregateFunctionProperties properties;

Expand Down Expand Up @@ -259,7 +283,7 @@ void AggregateRelParser::addMergingAggregatedStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
Aggregator::Params params(
grouping_keys,
aggregate_descriptions,
Expand Down Expand Up @@ -298,7 +322,7 @@ void AggregateRelParser::addCompleteModeAggregatedStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);
if (enable_streaming_aggregating)
{
Expand Down Expand Up @@ -376,7 +400,7 @@ void AggregateRelParser::addAggregatingStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true);

if (enable_streaming_aggregating)
Expand Down
16 changes: 11 additions & 5 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#include <Common/MergeTreeTool.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
#include <Common/JNIUtils.h>

namespace DB
{
Expand Down Expand Up @@ -312,12 +313,17 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait
auto iter = rel.local_files().items().at(0).uri_file();
auto pos = iter.find(':');
auto iter_index = std::stoi(iter.substr(pos + 1, iter.size()));
jobject input_iter = input_iters[iter_index];
bool materialize_input = materialize_inputs[iter_index];

GET_JNIENV(env)
SCOPE_EXIT({CLEAN_JNIENV});
auto * first_block = SourceFromJavaIter::peekBlock(env, input_iter);

/// Try to decide header from the first block read from Java iterator. Thus AggregateFunction with parameters has more precise types.
auto header = first_block ? first_block->cloneEmpty() : TypeParser::buildBlockFromNamedStruct(rel.base_schema());
auto source = std::make_shared<SourceFromJavaIter>(context, std::move(header), input_iter, materialize_input, first_block);

auto source = std::make_shared<SourceFromJavaIter>(
context,
TypeParser::buildBlockFromNamedStruct(rel.base_schema()),
input_iters[iter_index],
materialize_inputs[iter_index]);
QueryPlanStepPtr source_step = std::make_unique<ReadFromPreparedSource>(Pipe(source));
source_step->setStepDescription("Read From Java Iter");
return source_step;
Expand Down
27 changes: 14 additions & 13 deletions cpp-ch/local-engine/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ DB::Block TypeParser::buildBlockFromNamedStruct(
auto tmp_ctx = DB::Context::createCopy(SerializedPlanParser::global_context);
SerializedPlanParser tmp_plan_parser(tmp_ctx);
auto function_parser = AggregateFunctionParserFactory::instance().get(name_parts[3], &tmp_plan_parser);
/// This may remove elements from args_types, because some of them are used to determine CH function name, but not needed for the following
/// call `AggregateFunctionFactory::instance().get`
auto agg_function_name = function_parser->getCHFunctionName(args_types);
auto action = NullsAction::EMPTY;
ch_type = AggregateFunctionFactory::instance()
Expand Down Expand Up @@ -316,21 +318,20 @@ DB::Block TypeParser::buildBlockFromNamedStructWithoutDFS(const substrait::Named
return res;
}

bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type)
bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const DataTypePtr & ch_type, bool ignore_nullability)
{
const auto parsed_ch_type = TypeParser::parseType(substrait_type);
// if it's only different in nullability, we consider them same.
// this will be problematic for some functions being not-null in spark but nullable in clickhouse.
// e.g. murmur3hash
const auto a = removeNullable(parsed_ch_type);
const auto b = removeNullable(ch_type);
return a->equals(*b);
}

bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type)
{
const auto parsed_ch_type = TypeParser::parseType(substrait_type);
return parsed_ch_type->equals(*ch_type);
if (ignore_nullability)
{
// if it's only different in nullability, we consider them same.
// this will be problematic for some functions being not-null in spark but nullable in clickhouse.
// e.g. murmur3hash
const auto a = removeNullable(parsed_ch_type);
const auto b = removeNullable(ch_type);
return a->equals(*b);
}
else
return parsed_ch_type->equals(*ch_type);
}

DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type)
Expand Down
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Parser/TypeParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ namespace local_engine
/// Build block from substrait NamedStruct without DFS rules, different from buildBlockFromNamedStruct
static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct& struct_);

static bool isTypeMatched(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type);
static bool isTypeMatchedWithNullability(const substrait::Type& substrait_type, const DB::DataTypePtr& ch_type);
static bool isTypeMatched(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type, bool ignore_nullability = true);

private:
/// Mapping spark type names to CH type names.
Expand Down
7 changes: 4 additions & 3 deletions cpp-ch/local-engine/Parser/WindowRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ WindowRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & re
for (auto & it : window_descriptions)
{
auto & win = it.second;

auto window_step = std::make_unique<DB::WindowStep>(current_plan->getCurrentDataStream(), win, win.window_functions, false);
window_step->setStepDescription("Window step for window '" + win.window_name + "'");
steps.emplace_back(window_step.get());
Expand Down Expand Up @@ -328,13 +328,14 @@ void WindowRelParser::tryAddProjectionBeforeWindow()
for (auto & win_info : win_infos )
{
auto arg_nodes = win_info.function_parser->parseFunctionArguments(win_info.parser_func_info, actions_dag);
// This may remove elements from arg_nodes, because some of them are converted to CH func parameters.
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes);
for (auto & arg_node : arg_nodes)
{
win_info.arg_column_names.emplace_back(arg_node->result_name);
win_info.arg_column_types.emplace_back(arg_node->result_type);
actions_dag->addOrReplaceInOutputs(*arg_node);
}
win_info.params = win_info.function_parser->parseFunctionParameters(win_info.parser_func_info, arg_nodes);
}
}

if (actions_dag->dumpDAG() != dag_footprint)
Expand Down
Loading

0 comments on commit c0bad12

Please sign in to comment.