From 20deb8e9ca3f5412d623d3769ca3582890754c1e Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Sat, 2 Mar 2024 20:23:09 +0800 Subject: [PATCH] fix bugs --- .../ApproxPercentileParser.cpp | 99 +++++++++++-------- .../ApproxPercentileParser.h | 9 +- 2 files changed, 62 insertions(+), 46 deletions(-) diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp index 7fde191a2cdc..1ea490c69c46 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp @@ -37,6 +37,46 @@ namespace ErrorCodes namespace local_engine { +void ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const +{ + if (size != expect) + throw Exception( + DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Function {} in phase {} requires exactly {} arguments but got {} arguments", + getName(), + magic_enum::enum_name(phase), + expect, + size); +} + +const substrait::Expression::Literal & +ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const +{ + if (!expr.has_literal()) + throw Exception( + DB::ErrorCodes::BAD_ARGUMENTS, + "The argument of function {} in phase {} must be literal, but is {}", + getName(), + magic_enum::enum_name(phase), + expr.DebugString()); + return expr.literal(); +} + +String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo & func_info) const +{ + const auto & output_type = func_info.output_type; + return output_type.has_list() ? "quantilesGK" : "quantileGK"; +} + +String ApproxPercentileParser::getCHFunctionName(const DB::DataTypes & types) const +{ + /// Always invoked during second stage + assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT, types.size(), 1); + + auto type = removeNullable(types[0]); + return isArray(type) ? "quantilesGK" : "quantileGK"; +} + DB::Array ApproxPercentileParser::parseFunctionParameters( const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const { @@ -44,49 +84,28 @@ DB::Array ApproxPercentileParser::parseFunctionParameters( { Array params; const auto & arguments = func_info.arguments; - if (arguments.size() != 3) - throw Exception( - DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Function {} in phase {} requires exactly 3 arguments", - getName(), - magic_enum::enum_name(func_info.phase)); + assertArgumentsSize(func_info.phase, arguments.size(), 3); const auto & accuracy_expr = arguments[2].value(); - if (accuracy_expr.has_literal()) + const auto & accuracy_literal = assertAndGetLiteral(func_info.phase, accuracy_expr); + auto [type1, field1] = parseLiteral(accuracy_literal); + params.emplace_back(std::move(field1)); + + const auto & percentage_expr = arguments[1].value(); + const auto & percentage_literal = assertAndGetLiteral(func_info.phase, percentage_expr); + auto [type2, field2] = parseLiteral(percentage_literal); + if (isArray(type2)) { - auto [type, field] = parseLiteral(accuracy_expr.literal()); - params.emplace_back(std::move(field)); + /// Multiple percentages for quantilesGK + const Array & percentags = field2.get(); + for (const auto & percentage : percentags) + params.emplace_back(percentage); } else - throw Exception( - DB::ErrorCodes::BAD_ARGUMENTS, - "The third argument of function {} in phase {} must be literal, but is {}", - getName(), - magic_enum::enum_name(func_info.phase), - accuracy_expr.DebugString()); - - - const auto & arg_percentage = arguments[1].value(); - if (arg_percentage.has_literal()) { - auto [type, field] = parseLiteral(arg_percentage.literal()); - if (isArray(type)) - { - /// Multiple percentages - const Array & percentags = field.get(); - for (const auto & percentage : percentags) - params.emplace_back(percentage); - } - else - params.emplace_back(std::move(field)); + /// Single percentage for quantileGK + params.emplace_back(std::move(field2)); } - else - throw Exception( - DB::ErrorCodes::BAD_ARGUMENTS, - "The second argument of function {} in phase {} must be literal, but is {}", - getName(), - magic_enum::enum_name(func_info.phase), - arg_percentage.DebugString()); /// Delete percentage and accuracy argument for clickhouse compatiability arg_nodes.resize(1); @@ -94,13 +113,7 @@ DB::Array ApproxPercentileParser::parseFunctionParameters( } else { - if (arg_nodes.size() != 1) - throw Exception( - DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Function {} in phase {} requires exactly 1 arguments", - getName(), - magic_enum::enum_name(func_info.phase)); - + assertArgumentsSize(func_info.phase, arg_nodes.size(), 1); const auto & result_type = arg_nodes[0]->result_type; const auto * aggregate_function_type = DB::checkAndGetDataType(result_type.get()); if (!aggregate_function_type) diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h index e9cc4ca359fe..bd1019bac499 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h @@ -18,7 +18,6 @@ #include - /* spark: approx_percentile(col, percentage [, accuracy]) 1. When percentage is an array literal, spark returns an array of percentiles, corresponding to CH: quantilesGK(accuracy, percentage[0], ...)(col) @@ -34,12 +33,16 @@ class ApproxPercentileParser : public AggregateFunctionParser ~ApproxPercentileParser() override = default; String getName() const override { return name; } static constexpr auto name = "approx_percentile"; - String getCHFunctionName(const CommonFunctionInfo &) const override { return "quantileGK"; } - String getCHFunctionName(const DB ::DataTypes &) const override { return "quantileGK"; } + String getCHFunctionName(const CommonFunctionInfo & func_info) const override; + String getCHFunctionName(const DB::DataTypes & types) const override; DB::Array parseFunctionParameters(const CommonFunctionInfo & /*func_info*/, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override; DB::Array getDefaultFunctionParameters() const override; + +private: + void assertArgumentsSize(substrait::AggregationPhase phase, size_t size, size_t expect) const; + const substrait::Expression::Literal & assertAndGetLiteral(substrait::AggregationPhase phase, const substrait::Expression & expr) const; }; }