Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Mar 4, 2024
1 parent d5be6fa commit dbcd029
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,70 +37,83 @@ namespace ErrorCodes
namespace local_engine
{

void ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const
{
if (size != expect)
throw Exception(
DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} in phase {} requires exactly {} arguments but got {} arguments",
getName(),
magic_enum::enum_name(phase),
expect,
size);
}

const substrait::Expression::Literal &
ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const
{
if (!expr.has_literal())
throw Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"The argument of function {} in phase {} must be literal, but is {}",
getName(),
magic_enum::enum_name(phase),
expr.DebugString());
return expr.literal();
}

String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo & func_info) const
{
const auto & output_type = func_info.output_type;
return output_type.has_list() ? "quantilesGK" : "quantileGK";
}

String ApproxPercentileParser::getCHFunctionName(const DB::DataTypes & types) const
{
/// Always invoked during second stage
assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), 1);

auto type = removeNullable(types[0]);
return isArray(type) ? "quantilesGK" : "quantileGK";
}

DB::Array ApproxPercentileParser::parseFunctionParameters(
const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const
{
if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT)
{
Array params;
const auto & arguments = func_info.arguments;
if (arguments.size() != 3)
throw Exception(
DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} in phase {} requires exactly 3 arguments",
getName(),
magic_enum::enum_name(func_info.phase));
assertArgumentsSize(func_info.phase, arguments.size(), 3);

const auto & accuracy_expr = arguments[2].value();
if (accuracy_expr.has_literal())
const auto & accuracy_literal = assertAndGetLiteral(func_info.phase, accuracy_expr);
auto [type1, field1] = parseLiteral(accuracy_literal);
params.emplace_back(std::move(field1));

const auto & percentage_expr = arguments[1].value();
const auto & percentage_literal = assertAndGetLiteral(func_info.phase, percentage_expr);
auto [type2, field2] = parseLiteral(percentage_literal);
if (isArray(type2))
{
auto [type, field] = parseLiteral(accuracy_expr.literal());
params.emplace_back(std::move(field));
/// Multiple percentages for quantilesGK
const Array & percentags = field2.get<Array>();
for (const auto & percentage : percentags)
params.emplace_back(percentage);
}
else
throw Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"The third argument of function {} in phase {} must be literal, but is {}",
getName(),
magic_enum::enum_name(func_info.phase),
accuracy_expr.DebugString());


const auto & arg_percentage = arguments[1].value();
if (arg_percentage.has_literal())
{
auto [type, field] = parseLiteral(arg_percentage.literal());
if (isArray(type))
{
/// Multiple percentages
const Array & percentags = field.get<Array>();
for (const auto & percentage : percentags)
params.emplace_back(percentage);
}
else
params.emplace_back(std::move(field));
/// Single percentage for quantileGK
params.emplace_back(std::move(field2));
}
else
throw Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"The second argument of function {} in phase {} must be literal, but is {}",
getName(),
magic_enum::enum_name(func_info.phase),
arg_percentage.DebugString());

/// Delete percentage and accuracy argument for clickhouse compatiability
arg_nodes.resize(1);
return params;
}
else
{
if (arg_nodes.size() != 1)
throw Exception(
DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} in phase {} requires exactly 1 arguments",
getName(),
magic_enum::enum_name(func_info.phase));

assertArgumentsSize(func_info.phase, arg_nodes.size(), 1);
const auto & result_type = arg_nodes[0]->result_type;
const auto * aggregate_function_type = DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(result_type.get());
if (!aggregate_function_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <Parser/AggregateFunctionParser.h>



/*
spark: approx_percentile(col, percentage [, accuracy])
1. When percentage is an array literal, spark returns an array of percentiles, corresponding to CH: quantilesGK(accuracy, percentage[0], ...)(col)
Expand All @@ -34,12 +33,16 @@ class ApproxPercentileParser : public AggregateFunctionParser
~ApproxPercentileParser() override = default;
String getName() const override { return name; }
static constexpr auto name = "approx_percentile";
String getCHFunctionName(const CommonFunctionInfo &) const override { return "quantileGK"; }
String getCHFunctionName(const DB ::DataTypes &) const override { return "quantileGK"; }
String getCHFunctionName(const CommonFunctionInfo & func_info) const override;
String getCHFunctionName(const DB::DataTypes & types) const override;

DB::Array
parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override;

DB::Array getDefaultFunctionParameters() const override;

private:
void assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const;
const substrait::Expression::Literal & assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const;
};
}

0 comments on commit dbcd029

Please sign in to comment.