Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
zclllyybb committed Nov 13, 2024
1 parent 0aba1ed commit ee0de9b
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 221 deletions.
25 changes: 10 additions & 15 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,15 @@

namespace doris::vectorized {

template <template <typename, bool> class AggregateFunctionTemplate,
template <typename> class NameData, template <typename, typename> class Data,
bool is_nullable = false>
template <template <typename> class Function, template <typename> class Data>
AggregateFunctionPtr create_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable,
bool custom_nullable) {
const bool result_is_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create< \
AggregateFunctionTemplate<NameData<Data<TYPE, BaseData<TYPE>>>, is_nullable>>( \
custom_nullable ? remove_nullable(argument_types) : argument_types, \
result_is_nullable);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<Function<Data<TYPE>>>(argument_types, \
result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

Expand All @@ -55,16 +50,16 @@ AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string
const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
return create_function_single_value<AggregateFunctionSampCovariance, SampData>(
name, argument_types, result_is_nullable);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionPop, CovarName, PopData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
return create_function_single_value<AggregateFunctionSampCovariance, PopData>(
name, argument_types, result_is_nullable);
}

void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory) {
Expand Down
117 changes: 22 additions & 95 deletions be/src/vec/aggregate_functions/aggregate_function_covar.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,35 @@

#include <glog/logging.h>

#include "agent/be_exec_version_manager.h"
#define POP true
#define NOTPOP false
#define NULLABLE true
#define NOTNULLABLE false

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>

#include "olap/olap_common.h"
#include "runtime/decimalv2_value.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"

namespace doris {
namespace vectorized {
namespace doris::vectorized {

class Arena;
class BufferReadable;
class BufferWritable;
template <typename T>
class ColumnDecimal;
template <typename>
class ColumnVector;
} // namespace vectorized
} // namespace doris

namespace doris::vectorized {

template <typename T>
struct BaseData {
BaseData() : sum_x(0.0), sum_y(0.0), sum_xy(0.0), count(0) {}
BaseData() = default;
virtual ~BaseData() = default;
static DataTypePtr get_return_type() { return std::make_shared<DataTypeFloat64>(); }

void write(BufferWritable& buf) const {
write_binary(sum_x, buf);
Expand Down Expand Up @@ -122,23 +106,26 @@ struct BaseData {
count += 1;
}

double sum_x;
double sum_y;
double sum_xy;
int64_t count;
double sum_x {};
double sum_y {};
double sum_xy {};
int64_t count {};
};

template <typename T, typename Data>
struct PopData : Data {
template <typename T>
struct PopData : BaseData<T> {
static const char* name() { return "covar"; }

void insert_result_into(IColumn& to) const {
auto& col = assert_cast<ColumnFloat64&>(to);
col.get_data().push_back(this->get_pop_result());
}
static DataTypePtr get_return_type() { return std::make_shared<DataTypeNumber<Float64>>(); }
};

template <typename T, typename Data>
struct SampData : Data {
template <typename T>
struct SampData : BaseData<T> {
static const char* name() { return "covar_samp"; }

void insert_result_into(IColumn& to) const {
auto& col = assert_cast<ColumnFloat64&>(to);
if (this->count == 1 || this->count == 0) {
Expand All @@ -147,27 +134,14 @@ struct SampData : Data {
col.get_data().push_back(this->get_samp_result());
}
}
static DataTypePtr get_return_type() { return std::make_shared<DataTypeNumber<Float64>>(); }
};

template <typename Data>
struct CovarName : Data {
static const char* name() { return "covar"; }
};

template <typename Data>
struct CovarSampName : Data {
static const char* name() { return "covar_samp"; }
};

template <bool is_pop, typename Data, bool is_nullable>
class AggregateFunctionSampCovariance
: public IAggregateFunctionDataHelper<
Data, AggregateFunctionSampCovariance<is_pop, Data, is_nullable>> {
: public IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>> {
public:
AggregateFunctionSampCovariance(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<
Data, AggregateFunctionSampCovariance<is_pop, Data, is_nullable>>(
: IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>>(
argument_types_) {}

String get_name() const override { return Data::name(); }
Expand All @@ -176,39 +150,7 @@ class AggregateFunctionSampCovariance

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (is_pop) {
this->data(place).add(columns[0], columns[1], row_num);
} else {
if constexpr (is_nullable) { //this if check could remove with old function
// nullable means at least one child is null.
// so here, maybe JUST ONE OF ups is null. so nullptr perhaps in ..._x or ..._y!
const auto* nullable_column_x = check_and_get_column<ColumnNullable>(columns[0]);
const auto* nullable_column_y = check_and_get_column<ColumnNullable>(columns[1]);

if (nullable_column_x && nullable_column_y) { // both nullable
if (!nullable_column_x->is_null_at(row_num) &&
!nullable_column_y->is_null_at(row_num)) {
this->data(place).add(&nullable_column_x->get_nested_column(),
&nullable_column_y->get_nested_column(), row_num);
}
} else if (nullable_column_x) { // x nullable
if (!nullable_column_x->is_null_at(row_num)) {
this->data(place).add(&nullable_column_x->get_nested_column(), columns[1],
row_num);
}
} else if (nullable_column_y) { // y nullable
if (!nullable_column_y->is_null_at(row_num)) {
this->data(place).add(columns[0], &nullable_column_y->get_nested_column(),
row_num);
}
} else {
throw Exception(ErrorCode::INTERNAL_ERROR,
"Nullable function {} get non-nullable columns!", get_name());
}
} else {
this->data(place).add(columns[0], columns[1], row_num);
}
}
this->data(place).add(columns[0], columns[1], row_num);
}

void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
Expand All @@ -232,19 +174,4 @@ class AggregateFunctionSampCovariance
}
};

template <typename Data, bool is_nullable>
class AggregateFunctionSamp final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
public:
AggregateFunctionSamp(const DataTypes& argument_types_)
: AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable>(argument_types_) {}
};

template <typename Data, bool is_nullable>
class AggregateFunctionPop final : public AggregateFunctionSampCovariance<POP, Data, is_nullable> {
public:
AggregateFunctionPop(const DataTypes& argument_types_)
: AggregateFunctionSampCovariance<POP, Data, is_nullable>(argument_types_) {}
};

} // namespace doris::vectorized
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

#pragma once

#include "factory_helpers.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_map.h"
#include "vec/columns/column_nullable.h"
Expand All @@ -28,10 +26,8 @@
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
#include "vec/functions/function.h"
#include "vec/io/io_helper.h"

namespace doris::vectorized {

Expand All @@ -51,7 +47,7 @@ struct Value {

void insert_into(IColumn& to) const {
if constexpr (arg_is_nullable) {
auto* col = assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(_ptr);
const auto* col = assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(_ptr);
assert_cast<ColVecType&, TypeCheckOnRelease::DISABLE>(to).insert_from(
col->get_nested_column(), _offset);
} else {
Expand Down Expand Up @@ -89,7 +85,8 @@ struct CopiedValue : public Value<ColVecType, arg_is_nullable> {
// because the address have meaningless, only need it to check is nullptr
this->_ptr = (IColumn*)0x00000001;
if constexpr (arg_is_nullable) {
auto* col = assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(column);
const auto* col =
assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(column);
if (col->is_null_at(row)) {
this->reset();
return;
Expand Down Expand Up @@ -149,8 +146,9 @@ struct ReaderFirstAndLastData {
bool _has_value = false;
};

template <typename Data>
struct ReaderFunctionFirstData : Data {
template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
struct ReaderFunctionFirstData
: ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, is_copy> {
void add(int64_t row, const IColumn** columns) {
if (this->has_set_value()) {
return;
Expand All @@ -160,13 +158,15 @@ struct ReaderFunctionFirstData : Data {
static const char* name() { return "first_value"; }
};

template <typename Data>
struct ReaderFunctionFirstNonNullData : Data {
template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
struct ReaderFunctionFirstNonNullData
: ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, is_copy> {
void add(int64_t row, const IColumn** columns) {
if (this->has_set_value()) {
return;
}
if constexpr (Data::nullable) {
if constexpr (ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable,
is_copy>::nullable) {
const auto* nullable_column =
assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(columns[0]);
if (nullable_column->is_null_at(row)) {
Expand All @@ -178,16 +178,19 @@ struct ReaderFunctionFirstNonNullData : Data {
static const char* name() { return "first_non_null_value"; }
};

template <typename Data>
struct ReaderFunctionLastData : Data {
template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
struct ReaderFunctionLastData
: ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, is_copy> {
void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
static const char* name() { return "last_value"; }
};

template <typename Data>
struct ReaderFunctionLastNonNullData : Data {
template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
struct ReaderFunctionLastNonNullData
: ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, is_copy> {
void add(int64_t row, const IColumn** columns) {
if constexpr (Data::nullable) {
if constexpr (ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable,
is_copy>::nullable) {
const auto* nullable_column =
assert_cast<const ColumnNullable*, TypeCheckOnRelease::DISABLE>(columns[0]);
if (nullable_column->is_null_at(row)) {
Expand Down Expand Up @@ -256,17 +259,18 @@ class ReaderFunctionData final
DataTypePtr _argument_type;
};

template <template <typename> class AggregateFunctionTemplate, template <typename> class Impl,
bool result_is_nullable, bool arg_is_nullable, bool is_copy = false>
template <template <typename, bool, bool, bool> class FunctionData, bool result_is_nullable,
bool arg_is_nullable, bool is_copy = false>
AggregateFunctionPtr create_function_single_value(const String& name,
const DataTypes& argument_types) {
auto type = remove_nullable(argument_types[0]);
WhichDataType which(*type);

#define DISPATCH(TYPE, COLUMN_TYPE) \
if (which.idx == TypeIndex::TYPE) \
return std::make_shared<AggregateFunctionTemplate<Impl<ReaderFirstAndLastData< \
COLUMN_TYPE, result_is_nullable, arg_is_nullable, is_copy>>>>(argument_types);
#define DISPATCH(TYPE, COLUMN_TYPE) \
if (which.idx == TypeIndex::TYPE) \
return std::make_shared<ReaderFunctionData< \
FunctionData<COLUMN_TYPE, result_is_nullable, arg_is_nullable, is_copy>>>( \
argument_types);
TYPE_TO_COLUMN_TYPE(DISPATCH)
#undef DISPATCH

Expand All @@ -285,9 +289,9 @@ AggregateFunctionPtr create_function_single_value(const String& name,
std::visit( \
[&](auto result_is_nullable, auto arg_is_nullable) { \
res = AggregateFunctionPtr( \
create_function_single_value<ReaderFunctionData, FUNCTION_DATA, \
result_is_nullable, arg_is_nullable, \
is_copy>(name, argument_types)); \
create_function_single_value<FUNCTION_DATA, result_is_nullable, \
arg_is_nullable, is_copy>( \
name, argument_types)); \
}, \
make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable)); \
if (!res) { \
Expand All @@ -303,4 +307,6 @@ CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first_non_nu
CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last, ReaderFunctionLastData);
CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last_non_null_value,
ReaderFunctionLastNonNullData);
#undef CREATE_READER_FUNCTION_WITH_NAME_AND_DATA

} // namespace doris::vectorized
Loading

0 comments on commit ee0de9b

Please sign in to comment.