Skip to content

Commit

Permalink
fix avg
Browse files Browse the repository at this point in the history
  • Loading branch information
jacktengg committed Oct 16, 2023
1 parent 0eb7bb9 commit 223105a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
11 changes: 11 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,18 @@ struct Avg {
template <typename T>
using AggregateFuncAvg = typename Avg<T>::Function;

template <typename T>
struct AvgDecimal256 {
using FieldType = typename AvgNearestFieldTypeTrait<T, true>::Type;
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>;
};

template <typename T>
using AggregateFuncAvgDecimal256 = typename Avg<T>::Function;

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("avg", creator_with_type::creator<AggregateFuncAvg>);
factory.register_function_both("avg_decimal256",
creator_with_type::creator<AggregateFuncAvgDecimal256>);
}
} // namespace doris::vectorized
12 changes: 11 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ namespace doris::vectorized {

template <typename T>
struct AggregateFunctionAvgData {
using ResultType = T;
T sum {};
UInt64 count = 0;

Expand Down Expand Up @@ -107,15 +108,24 @@ template <typename T, typename Data>
class AggregateFunctionAvg final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>> {
public:
/*
using ResultType = DisposeDecimal<T, Float64>;
using ResultDataType =
std::conditional_t<IsDecimalV2<T>, DataTypeDecimal<Decimal128>,
std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<Decimal128I>,
DataTypeNumber<Float64>>>;
*/
using ResultType = std::conditional_t<
IsDecimalV2<T>, Decimal128,
std::conditional_t<IsDecimalNumber<T>, typename Data::ResultType, Float64>>;
using ResultDataType = std::conditional_t<
IsDecimalV2<T>, DataTypeDecimal<Decimal128>,
std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<typename Data::ResultType>,
DataTypeNumber<Float64>>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
using ColVecResult =
std::conditional_t<IsDecimalV2<T>, ColumnDecimal<Decimal128>,
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128I>,
std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<typename Data::ResultType>,
ColumnVector<Float64>>>;

/// ctor for native types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class AggregateFunctionSimpleFactory {

std::string name_str = name;
if (enable_decima256) {
if (name_str == "sum") {
if (name_str == "sum" || name_str == "avg") {
name_str += "_decimal256";
}
}
Expand Down
17 changes: 16 additions & 1 deletion be/src/vec/core/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct NearestFieldTypeImpl {
template <typename T>
using NearestFieldType = typename NearestFieldTypeImpl<T>::Type;

template <typename T>
template <typename T, bool decimal256 = false>
struct AvgNearestFieldTypeTrait {
using Type = typename NearestFieldTypeImpl<T>::Type;
};
Expand Down Expand Up @@ -100,6 +100,21 @@ struct AvgNearestFieldTypeTrait<Int64> {
using Type = double;
};

template <>
struct AvgNearestFieldTypeTrait<Decimal32, true> {
using Type = Decimal256;
};

template <>
struct AvgNearestFieldTypeTrait<Decimal64, true> {
using Type = Decimal256;
};

template <>
struct AvgNearestFieldTypeTrait<Decimal128I, true> {
using Type = Decimal256;
};

class Field;

using FieldVector = std::vector<Field>;
Expand Down

0 comments on commit 223105a

Please sign in to comment.