Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Mar 14, 2024
1 parent 20deb8e commit 2a92e73
Show file tree
Hide file tree
Showing 14 changed files with 36 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ case class CHHashAggregateExecTransformer(
case approxPercentile: ApproximatePercentile =>
var fields = Seq[(DataType, Boolean)]()
fields = fields :+ (approxPercentile.child.dataType, approxPercentile.child.nullable)
fields = fields :+ (approxPercentile.percentageExpression.dataType,
approxPercentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2452,5 +2452,18 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
spark.sql("drop table test_tbl_4279")
}

test("aggregate function approx_percentile") {
// single percentage
val sql1 = "select l_linenumber % 10, approx_percentile(l_extendedprice, 0.5) " +
"from lineitem group by l_linenumber % 10"
runQueryAndCompare(sql1)({ _ => })

// multiple percentages
val sql2 =
"select l_linenumber % 10, approx_percentile(l_extendedprice, array(0.1, 0.2, 0.3)) " +
"from lineitem group by l_linenumber % 10"
runQueryAndCompare(sql2)({ _ => })
}
}
// scalastyle:on line.size.limit
6 changes: 4 additions & 2 deletions cpp-ch/local-engine/Parser/AggregateFunctionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class AggregateFunctionParser
/// In some special cases, different arguments size or different arguments types may refer to different
/// CH function implementation.
virtual String getCHFunctionName(const CommonFunctionInfo & func_info) const = 0;

/// In most cases, arguments size and types are enough to determine the CH function implementation.
/// This is only be used in SerializedPlanParser::parseNameStructure.
virtual String getCHFunctionName(const DB::DataTypes & args) const = 0;
/// It is only be used in TypeParser::buildBlockFromNamedStruct
/// Users are allowed to modify arg types to make it fit for ggregateFunctionFactory::instance().get(...) in TypeParser::buildBlockFromNamedStruct
virtual String getCHFunctionName(DB::DataTypes & args) const = 0;

/// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
virtual DB::ActionsDAG::NodeRawConstPtrs
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/AggregateRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ void AggregateRelParser::addMergingAggregatedStep()
{
AggregateDescriptions aggregate_descriptions;
buildAggregateDescriptions(aggregate_descriptions);
auto settings = getContext()->getSettingsRef();
const auto & settings = getContext()->getSettingsRef();
Aggregator::Params params(
grouping_keys,
aggregate_descriptions,
Expand Down
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & s
auto tmp_ctx = DB::Context::createCopy(SerializedPlanParser::global_context);
SerializedPlanParser tmp_plan_parser(tmp_ctx);
auto function_parser = AggregateFunctionParserFactory::instance().get(name_parts[3], &tmp_plan_parser);
/// This may remove elements from args_types, because some of them are used to determine CH function name, but not needed for the following
/// call `AggregateFunctionFactory::instance().get`
auto agg_function_name = function_parser->getCHFunctionName(args_types);
auto action = NullsAction::EMPTY;
ch_type = AggregateFunctionFactory::instance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo & func
return output_type.has_list() ? "quantilesGK" : "quantileGK";
}

String ApproxPercentileParser::getCHFunctionName(const DB::DataTypes & types) const
String ApproxPercentileParser::getCHFunctionName(DB::DataTypes & types) const
{
/// Always invoked during second stage
assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), 1);
assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), 2);

auto type = removeNullable(types[0]);
auto type = removeNullable(types[1]);
types.resize(1);
return isArray(type) ? "quantilesGK" : "quantileGK";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ApproxPercentileParser : public AggregateFunctionParser
String getName() const override { return name; }
static constexpr auto name = "approx_percentile";
String getCHFunctionName(const CommonFunctionInfo & func_info) const override;
String getCHFunctionName(const DB::DataTypes & types) const override;
String getCHFunctionName(DB::DataTypes & types) const override;

DB::Array
parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AggregateFunctionParserBloomFilterAgg : public AggregateFunctionParser
String getName() const override { return name; }
static constexpr auto name = "bloom_filter_agg";
String getCHFunctionName(const CommonFunctionInfo &) const override { return "groupBloomFilterState"; }
String getCHFunctionName(const DB ::DataTypes &) const override { return "groupBloomFilterState"; }
String getCHFunctionName(DB::DataTypes &) const override { return "groupBloomFilterState"; }

DB::Array
parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CollectFunctionParser : public AggregateFunctionParser
throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement");
}

virtual String getCHFunctionName(const DB::DataTypes &) const override
virtual String getCHFunctionName(DB::DataTypes &) const override
{
throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement");
}
Expand Down Expand Up @@ -79,7 +79,7 @@ class CollectListParser : public CollectFunctionParser
static constexpr auto name = "collect_list";
String getName() const override { return name; }
String getCHFunctionName(const CommonFunctionInfo &) const override { return "groupArray"; }
String getCHFunctionName(const DB::DataTypes &) const override { return "groupArray"; }
String getCHFunctionName(DB::DataTypes &) const override { return "groupArray"; }
};

class CollectSetParser : public CollectFunctionParser
Expand All @@ -90,6 +90,6 @@ class CollectSetParser : public CollectFunctionParser
static constexpr auto name = "collect_set";
String getName() const override { return name; }
String getCHFunctionName(const CommonFunctionInfo &) const override { return "groupUniqArray"; }
String getCHFunctionName(const DB::DataTypes &) const override { return "groupUniqArray"; }
String getCHFunctionName(DB::DataTypes &) const override { return "groupUniqArray"; }
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace local_engine
String getName() const override { return #substait_name; } \
static constexpr auto name = #substait_name; \
String getCHFunctionName(const CommonFunctionInfo &) const override { return #ch_name; } \
String getCHFunctionName(const DB::DataTypes &) const override { return #ch_name; } \
String getCHFunctionName(DB::DataTypes &) const override { return #ch_name; } \
}; \
static const AggregateFunctionParserRegister<AggregateFunctionParser##cls_name> register_##cls_name = AggregateFunctionParserRegister<AggregateFunctionParser##cls_name>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ String CountParser::getCHFunctionName(const CommonFunctionInfo &) const
return "count";
}

String CountParser::getCHFunctionName(const DB::DataTypes &) const
String CountParser::getCHFunctionName(DB::DataTypes &) const
{
return "count";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CountParser : public AggregateFunctionParser
static constexpr auto name = "count";
String getName() const override { return name; }
String getCHFunctionName(const CommonFunctionInfo &) const override;
String getCHFunctionName(const DB::DataTypes &) const override;
String getCHFunctionName(DB::DataTypes &) const override;
DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class LeadParser : public AggregateFunctionParser
static constexpr auto name = "lead";
String getName() const override { return name; }
String getCHFunctionName(const CommonFunctionInfo &) const override { return "leadInFrame"; }
String getCHFunctionName(const DB::DataTypes &) const override { return "leadInFrame"; }
String getCHFunctionName(DB::DataTypes &) const override { return "leadInFrame"; }
DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override;
};
Expand All @@ -40,7 +40,7 @@ class LagParser : public AggregateFunctionParser
static constexpr auto name = "lag";
String getName() const override { return name; }
String getCHFunctionName(const CommonFunctionInfo &) const override { return "lagInFrame"; }
String getCHFunctionName(const DB::DataTypes &) const override { return "lagInFrame"; }
String getCHFunctionName(DB::DataTypes &) const override { return "lagInFrame"; }
DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override;
};
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/compile_commands.json

This file was deleted.

0 comments on commit 2a92e73

Please sign in to comment.