From e7dc94b8a65e454710d025ab96e0b3a5e827b40e Mon Sep 17 00:00:00 2001 From: jacktengg <18241664+jacktengg@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:29:11 +0800 Subject: [PATCH] [feature](decimal) support decimal256 --- be/src/common/consts.h | 12 + be/src/exec/olap_common.h | 21 +- be/src/exec/schema_scanner.cpp | 6 + .../schema_scanner/schema_columns_scanner.cpp | 4 +- be/src/exec/table_connector.cpp | 3 +- be/src/exec/text_converter.cpp | 13 + be/src/exprs/create_predicate_function.h | 12 +- be/src/exprs/runtime_filter.cpp | 4 + be/src/gutil/endian.h | 12 + be/src/olap/delete_handler.cpp | 2 + be/src/olap/field.h | 4 + be/src/olap/in_list_predicate.h | 2 +- be/src/olap/key_coder.cpp | 1 + be/src/olap/key_coder.h | 33 +- be/src/olap/olap_common.h | 4 +- be/src/olap/predicate_creator.h | 7 +- .../rowset/segment_v2/bitmap_index_writer.cpp | 3 + .../olap/rowset/segment_v2/bitshuffle_page.h | 1 + .../segment_v2/bloom_filter_index_writer.cpp | 1 + .../olap/rowset/segment_v2/encoding_info.cpp | 4 + .../segment_v2/inverted_index_writer.cpp | 6 + .../olap/rowset/segment_v2/zone_map_index.cpp | 3 +- be/src/olap/schema.cpp | 3 + be/src/olap/tablet_schema.cpp | 7 + be/src/olap/types.cpp | 1 + be/src/olap/types.h | 32 + be/src/olap/utils.h | 1 + be/src/pipeline/exec/scan_operator.cpp | 4 +- be/src/runtime/decimalv2_value.h | 7 + be/src/runtime/define_primitive_type.h | 3 +- be/src/runtime/fold_constant_executor.cpp | 3 +- be/src/runtime/primitive_type.cpp | 12 + be/src/runtime/primitive_type.h | 6 + be/src/runtime/raw_value.h | 2 + be/src/runtime/runtime_predicate.cpp | 4 + be/src/runtime/runtime_predicate.h | 6 + be/src/runtime/runtime_state.h | 4 + be/src/runtime/type_limit.h | 8 + be/src/runtime/types.cpp | 12 +- be/src/runtime/types.h | 38 +- be/src/util/binary_cast.hpp | 6 +- be/src/util/string_parser.hpp | 5 +- .../aggregate_function_avg.cpp | 11 + .../aggregate_function_avg.h | 24 +- .../aggregate_function_product.h | 2 +- 
.../aggregate_function_simple_factory.h | 8 +- .../aggregate_function_sum.cpp | 2 + .../aggregate_function_sum.h | 16 +- be/src/vec/aggregate_functions/helpers.h | 3 +- be/src/vec/columns/column_array.h | 7 +- be/src/vec/columns/column_decimal.cpp | 11 +- be/src/vec/columns/column_decimal.h | 2 +- be/src/vec/columns/columns_number.h | 1 + be/src/vec/common/arithmetic_overflow.h | 25 + be/src/vec/common/field_visitors.h | 2 + be/src/vec/common/hash_table/hash.h | 21 + be/src/vec/common/int_exp.h | 92 ++ be/src/vec/core/accurate_comparison.h | 111 ++ be/src/vec/core/call_on_type_index.h | 9 + be/src/vec/core/decimal_comparison.h | 31 +- be/src/vec/core/decomposed_float.h | 201 +++ be/src/vec/core/extended_types.h | 127 ++ be/src/vec/core/field.cpp | 1 + be/src/vec/core/field.h | 60 +- be/src/vec/core/types.h | 423 ++++- be/src/vec/core/wide_integer.h | 302 ++++ be/src/vec/core/wide_integer_impl.h | 1389 +++++++++++++++++ be/src/vec/core/wide_integer_to_string.h | 77 + .../vec/data_types/convert_field_to_type.cpp | 3 + be/src/vec/data_types/data_type.cpp | 2 + be/src/vec/data_types/data_type.h | 4 +- be/src/vec/data_types/data_type_decimal.cpp | 15 +- be/src/vec/data_types/data_type_decimal.h | 68 +- be/src/vec/data_types/data_type_factory.cpp | 10 + be/src/vec/data_types/get_least_supertype.cpp | 15 +- be/src/vec/data_types/number_traits.h | 9 + .../serde/data_type_decimal_serde.cpp | 2 + .../serde/data_type_decimal_serde.h | 12 +- .../parquet/byte_array_dict_decoder.cpp | 1 + .../parquet/byte_array_plain_decoder.cpp | 1 + .../parquet/fix_length_dict_decoder.hpp | 2 + .../parquet/fix_length_plain_decoder.cpp | 1 + be/src/vec/exec/jni_connector.cpp | 1 + be/src/vec/exec/scan/vscan_node.cpp | 4 +- be/src/vec/exec/vjdbc_connector.cpp | 3 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 2 +- be/src/vec/exprs/vexpr.cpp | 5 + be/src/vec/exprs/vexpr.h | 8 + .../array/function_array_aggregation.cpp | 1 + .../functions/array/function_array_apply.cpp | 2 + 
.../array/function_array_difference.h | 3 + .../functions/array/function_array_distinct.h | 3 + .../functions/array/function_array_element.h | 3 + .../array/function_array_enumerate_uniq.cpp | 2 + .../functions/array/function_array_index.h | 4 + .../vec/functions/array/function_array_join.h | 3 + .../functions/array/function_array_remove.h | 3 + .../functions/array/function_arrays_overlap.h | 4 + be/src/vec/functions/function.h | 11 +- .../functions/function_binary_arithmetic.h | 63 +- be/src/vec/functions/function_cast.h | 24 +- .../vec/functions/function_multi_same_args.h | 1 - be/src/vec/functions/function_string.h | 23 +- .../vec/functions/function_unary_arithmetic.h | 4 +- .../vec/functions/function_width_bucket.cpp | 3 + be/src/vec/functions/functions_comparison.h | 1 - be/src/vec/functions/if.cpp | 1 - be/src/vec/functions/least_greast.cpp | 6 +- be/src/vec/olap/olap_data_convertor.cpp | 3 + be/src/vec/sink/vtablet_block_convertor.cpp | 12 +- be/src/vec/sink/vtablet_block_convertor.h | 2 + .../vec/sink/writer/vmysql_table_writer.cpp | 3 +- be/test/vec/data_types/decimal_test.cpp | 81 + .../apache/doris/catalog/PrimitiveType.java | 11 +- .../org/apache/doris/catalog/ScalarType.java | 48 +- .../java/org/apache/doris/catalog/Type.java | 64 +- .../org/apache/doris/analysis/CastExpr.java | 1 + .../org/apache/doris/analysis/ColumnDef.java | 1 + .../java/org/apache/doris/analysis/Expr.java | 2 + .../apache/doris/analysis/LiteralExpr.java | 1 + .../apache/doris/analysis/StringLiteral.java | 1 + .../org/apache/doris/analysis/TypeDef.java | 29 + .../apache/doris/catalog/AliasFunction.java | 1 + .../java/org/apache/doris/catalog/Column.java | 1 + .../org/apache/doris/common/util/Util.java | 1 + .../apache/doris/mysql/MysqlSerializer.java | 4 +- .../exceptions/NotSupportedException.java | 28 + .../rules/FoldConstantRuleOnBE.java | 3 +- .../nereids/trees/expressions/Divide.java | 4 +- .../nereids/trees/expressions/Multiply.java | 8 +- 
.../functions/ComputePrecisionForSum.java | 7 +- .../trees/expressions/functions/agg/Avg.java | 17 +- .../doris/nereids/types/DecimalV3Type.java | 35 +- .../org/apache/doris/qe/ConnectProcessor.java | 12 +- .../org/apache/doris/qe/SessionVariable.java | 12 + .../apache/doris/qe/cache/PartitionRange.java | 1 + .../doris/rewrite/FoldConstantsRule.java | 3 +- gensrc/proto/internal_service.proto | 1 + gensrc/proto/types.proto | 1 + gensrc/thrift/PaloInternalService.thrift | 2 + gensrc/thrift/Types.thrift | 3 +- .../decimalv3/test_arithmetic_expressions.out | 141 +- .../datatype_p0/decimalv3/test_decimalv3.out | 6 + .../datatype_p0/decimalv3/test_predicate.out | 77 + .../aggregate/aggregate_decimal256.out | 97 ++ .../query_p0/join/test_join_decimal256.out | 41 + .../test_arithmetic_expressions.groovy | 275 +++- .../decimalv3/test_decimalv3.groovy | 49 +- .../decimalv3/test_decimalv3_overflow.groovy | 1 + .../decimalv3/test_predicate.groovy | 52 + .../aggregate/aggregate_decimal256.groovy | 151 ++ .../query_p0/join/test_join_decimal256.groovy | 97 ++ 152 files changed, 4704 insertions(+), 299 deletions(-) create mode 100644 be/src/vec/core/decomposed_float.h create mode 100644 be/src/vec/core/extended_types.h create mode 100644 be/src/vec/core/wide_integer.h create mode 100644 be/src/vec/core/wide_integer_impl.h create mode 100644 be/src/vec/core/wide_integer_to_string.h create mode 100644 be/test/vec/data_types/decimal_test.cpp create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/exceptions/NotSupportedException.java create mode 100644 regression-test/data/query_p0/aggregate/aggregate_decimal256.out create mode 100644 regression-test/data/query_p0/join/test_join_decimal256.out create mode 100644 regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy create mode 100644 regression-test/suites/query_p0/join/test_join_decimal256.groovy diff --git a/be/src/common/consts.h b/be/src/common/consts.h index 72942f75b22c995..7548f9a20269008 100644 
--- a/be/src/common/consts.h +++ b/be/src/common/consts.h @@ -30,8 +30,20 @@ const std::string ROWID_COL = "__DORIS_ROWID_COL__"; const std::string ROW_STORE_COL = "__DORIS_ROW_STORE_COL__"; const std::string DYNAMIC_COLUMN_NAME = "__DORIS_DYNAMIC_COL__"; +/// The maximum precision representable by a 4-byte decimal (Decimal4Value) constexpr int MAX_DECIMAL32_PRECISION = 9; +/// The maximum precision representable by a 8-byte decimal (Decimal8Value) constexpr int MAX_DECIMAL64_PRECISION = 18; +/// The maximum precision representable by a 16-byte decimal constexpr int MAX_DECIMAL128_PRECISION = 38; +/// The maximum precision representable by a 32-byte decimal +constexpr int MAX_DECIMAL256_PRECISION = 76; + +/// Must be kept in sync with FE's max precision/scale. +static constexpr int MAX_DECIMALV2_PRECISION = MAX_DECIMAL128_PRECISION; +static constexpr int MAX_DECIMALV2_SCALE = MAX_DECIMALV2_PRECISION; + +static constexpr int MAX_DECIMALV3_PRECISION = MAX_DECIMAL256_PRECISION; +static constexpr int MAX_DECIMALV3_SCALE = MAX_DECIMALV3_PRECISION; } // namespace BeConsts } // namespace doris diff --git a/be/src/exec/olap_common.h b/be/src/exec/olap_common.h index 6ef0edd7a304177..57d82f227a46ed5 100644 --- a/be/src/exec/olap_common.h +++ b/be/src/exec/olap_common.h @@ -55,6 +55,8 @@ std::string cast_to_string(T value, int scale) { return ((vectorized::Decimal)value).to_string(scale); } else if constexpr (primitive_type == TYPE_DECIMAL128I) { return ((vectorized::Decimal)value).to_string(scale); + } else if constexpr (primitive_type == TYPE_DECIMAL256) { + return ((vectorized::Decimal)value).to_string(scale); } else if constexpr (primitive_type == TYPE_TINYINT) { return std::to_string(static_cast(value)); } else if constexpr (primitive_type == TYPE_LARGEINT) { @@ -503,16 +505,15 @@ class OlapScanKeys { bool _is_convertible; }; -using ColumnValueRangeType = - std::variant, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - 
ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange, - ColumnValueRange, ColumnValueRange>; +using ColumnValueRangeType = std::variant< + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange, + ColumnValueRange, ColumnValueRange>; template const typename ColumnValueRange::CppType diff --git a/be/src/exec/schema_scanner.cpp b/be/src/exec/schema_scanner.cpp index 9733558284a8fda..3062b9d7be2cdc4 100644 --- a/be/src/exec/schema_scanner.cpp +++ b/be/src/exec/schema_scanner.cpp @@ -287,6 +287,12 @@ Status SchemaScanner::fill_dest_column_for_range(vectorized::Block* block, size_ reinterpret_cast(&num), 0); break; } + // case TYPE_DECIMAL256: { + // const vectorized::Int256 num = (reinterpret_cast(data))->value; + // reinterpret_cast(col_ptr)->insert_data( + // reinterpret_cast(&num), 0); + // break; + // } case TYPE_DECIMAL32: { const int32_t num = *reinterpret_cast(data); diff --git a/be/src/exec/schema_scanner/schema_columns_scanner.cpp b/be/src/exec/schema_scanner/schema_columns_scanner.cpp index 0c728643f770643..7fa8963ffa581e1 100644 --- a/be/src/exec/schema_scanner/schema_columns_scanner.cpp +++ b/be/src/exec/schema_scanner/schema_columns_scanner.cpp @@ -138,6 +138,7 @@ std::string SchemaColumnsScanner::_to_mysql_data_type_string(TColumnDesc& desc) case TPrimitiveType::DECIMAL32: case TPrimitiveType::DECIMAL64: case TPrimitiveType::DECIMAL128I: + case TPrimitiveType::DECIMAL256: case TPrimitiveType::DECIMALV2: { return "decimal"; } @@ -208,7 +209,8 @@ std::string SchemaColumnsScanner::_type_to_string(TColumnDesc& desc) { } case TPrimitiveType::DECIMAL32: case 
TPrimitiveType::DECIMAL64: - case TPrimitiveType::DECIMAL128I: { + case TPrimitiveType::DECIMAL128I: + case TPrimitiveType::DECIMAL256: { fmt::memory_buffer debug_string_buffer; fmt::format_to( debug_string_buffer, "decimalv3({}, {})", diff --git a/be/src/exec/table_connector.cpp b/be/src/exec/table_connector.cpp index 7aa90eda08678a5..47e54be1dc62435 100644 --- a/be/src/exec/table_connector.cpp +++ b/be/src/exec/table_connector.cpp @@ -250,7 +250,8 @@ Status TableConnector::convert_column_data(const vectorized::ColumnPtr& column_p } case TYPE_DECIMAL32: case TYPE_DECIMAL64: - case TYPE_DECIMAL128I: { + case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: { auto decimal_type = remove_nullable(type_ptr); auto val = decimal_type->to_string(*column, row); fmt::format_to(_insert_stmt_buffer, "{}", val); diff --git a/be/src/exec/text_converter.cpp b/be/src/exec/text_converter.cpp index 59417bc92651fc8..e14e70054245cfb 100644 --- a/be/src/exec/text_converter.cpp +++ b/be/src/exec/text_converter.cpp @@ -41,6 +41,7 @@ #include "vec/columns/column_struct.h" #include "vec/columns/column_vector.h" #include "vec/core/types.h" +#include "vec/core/wide_integer.h" #include "vec/runtime/vdatetime_value.h" namespace doris { @@ -290,6 +291,18 @@ bool TextConverter::_write_data(const TypeDescriptor& type_desc, .resize_fill(origin_size + rows, value); break; } + case TYPE_DECIMAL256: { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + wide::Int256 value = StringParser::string_to_decimal( + data, len, type_desc.precision, type_desc.scale, &result); + if (result != StringParser::PARSE_SUCCESS) { + parse_result = StringParser::PARSE_FAILURE; + break; + } + reinterpret_cast*>(col_ptr)->get_data().resize_fill( + origin_size + rows, value); + break; + } case TYPE_ARRAY: { auto col = reinterpret_cast(col_ptr); diff --git a/be/src/exprs/create_predicate_function.h b/be/src/exprs/create_predicate_function.h index 7d89141c443b37c..e1c4f1931b8a5ab 100644 --- 
a/be/src/exprs/create_predicate_function.h +++ b/be/src/exprs/create_predicate_function.h @@ -34,10 +34,11 @@ class MinmaxFunctionTraits { using BasePtr = MinMaxFuncBase*; template static BasePtr get_function() { - return new MinMaxNumFunc::CppType>, - typename PrimitiveTypeTraits::CppType>>(); + return new MinMaxNumFunc< + std::conditional_t::CppType>, + typename PrimitiveTypeTraits::CppType>>(); } }; @@ -106,7 +107,8 @@ class PredicateFunctionCreator { M(TYPE_STRING) \ M(TYPE_DECIMAL32) \ M(TYPE_DECIMAL64) \ - M(TYPE_DECIMAL128I) + M(TYPE_DECIMAL128I) \ + M(TYPE_DECIMAL256) template typename Traits::BasePtr create_predicate_function(PrimitiveType type) { diff --git a/be/src/exprs/runtime_filter.cpp b/be/src/exprs/runtime_filter.cpp index 3679ce4c84c917e..be8ab3b02e10bc5 100644 --- a/be/src/exprs/runtime_filter.cpp +++ b/be/src/exprs/runtime_filter.cpp @@ -99,6 +99,8 @@ PColumnType to_proto(PrimitiveType type) { return PColumnType::COLUMN_TYPE_DECIMAL64; case TYPE_DECIMAL128I: return PColumnType::COLUMN_TYPE_DECIMAL128I; + case TYPE_DECIMAL256: + return PColumnType::COLUMN_TYPE_DECIMAL256; case TYPE_CHAR: return PColumnType::COLUMN_TYPE_CHAR; case TYPE_VARCHAR: @@ -148,6 +150,8 @@ PrimitiveType to_primitive_type(PColumnType type) { return TYPE_DECIMAL64; case PColumnType::COLUMN_TYPE_DECIMAL128I: return TYPE_DECIMAL128I; + case PColumnType::COLUMN_TYPE_DECIMAL256: + return TYPE_DECIMAL256; case PColumnType::COLUMN_TYPE_VARCHAR: return TYPE_VARCHAR; case PColumnType::COLUMN_TYPE_CHAR: diff --git a/be/src/gutil/endian.h b/be/src/gutil/endian.h index 4434bb943b47d86..66d849f73cd5541 100644 --- a/be/src/gutil/endian.h +++ b/be/src/gutil/endian.h @@ -35,6 +35,7 @@ #include "gutil/int128.h" #include "gutil/integral_types.h" #include "gutil/port.h" +#include "vec/core/wide_integer.h" inline uint64 gbswap_64(uint64 host_int) { #if defined(__GNUC__) && defined(__x86_64__) && !defined(__APPLE__) @@ -59,6 +60,11 @@ inline unsigned __int128 gbswap_128(unsigned __int128 
host_int) { (static_cast(bswap_64(static_cast(host_int))) << 64); } +inline wide::UInt256 gbswap_256(wide::UInt256 host_int) { + wide::UInt256 result{gbswap_64(host_int.items[0]), gbswap_64(host_int.items[1]), gbswap_64(host_int.items[2]), gbswap_64(host_int.items[3])}; + return result; +} + // Swap bytes of a 24-bit value. inline uint32_t bswap_24(uint32_t x) { return ((x & 0x0000ffULL) << 16) | ((x & 0x00ff00ULL)) | ((x & 0xff0000ULL) >> 16); @@ -252,6 +258,9 @@ class BigEndian { static unsigned __int128 FromHost128(unsigned __int128 x) { return gbswap_128(x); } static unsigned __int128 ToHost128(unsigned __int128 x) { return gbswap_128(x); } + static wide::UInt256 FromHost256(wide::UInt256 x) { return gbswap_256(x); } + static wide::UInt256 ToHost256(wide::UInt256 x) { return gbswap_256(x); } + static bool IsLittleEndian() { return true; } #elif defined IS_BIG_ENDIAN @@ -271,6 +280,9 @@ class BigEndian { static uint128 FromHost128(uint128 x) { return x; } static uint128 ToHost128(uint128 x) { return x; } + static wide::UInt256 FromHost256(wide::UInt256 x) { return x; } + static wide::UInt256 ToHost256(wide::UInt256 x) { return x; } + static bool IsLittleEndian() { return false; } #endif /* ENDIAN */ diff --git a/be/src/olap/delete_handler.cpp b/be/src/olap/delete_handler.cpp index 886d11a8722015f..e52daed96677fca 100644 --- a/be/src/olap/delete_handler.cpp +++ b/be/src/olap/delete_handler.cpp @@ -199,6 +199,8 @@ bool DeleteHandler::is_condition_value_valid(const TabletColumn& column, return valid_decimal(value_str, column.precision(), column.frac()); case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: return valid_decimal(value_str, column.precision(), column.frac()); + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + return valid_decimal(value_str, column.precision(), column.frac()); case FieldType::OLAP_FIELD_TYPE_CHAR: case FieldType::OLAP_FIELD_TYPE_VARCHAR: return value_str.size() <= column.length(); diff --git a/be/src/olap/field.h b/be/src/olap/field.h index 
fc9c87bfb1c3aec..a92a8f121c149fc 100644 --- a/be/src/olap/field.h +++ b/be/src/olap/field.h @@ -589,6 +589,8 @@ class FieldFactory { [[fallthrough]]; case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: [[fallthrough]]; + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + [[fallthrough]]; case FieldType::OLAP_FIELD_TYPE_DATETIMEV2: { Field* field = new Field(column); field->set_precision(column.precision()); @@ -647,6 +649,8 @@ class FieldFactory { [[fallthrough]]; case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: [[fallthrough]]; + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + [[fallthrough]]; case FieldType::OLAP_FIELD_TYPE_DATETIMEV2: { Field* field = new Field(column); field->set_precision(column.precision()); diff --git a/be/src/olap/in_list_predicate.h b/be/src/olap/in_list_predicate.h index 329c9b8dc0e269e..60a8be73199b586 100644 --- a/be/src/olap/in_list_predicate.h +++ b/be/src/olap/in_list_predicate.h @@ -97,7 +97,7 @@ class InListPredicateBase : public ColumnPredicate { if constexpr (Type == TYPE_STRING || Type == TYPE_CHAR) { tmp = convert(*col, condition, arena); } else if constexpr (Type == TYPE_DECIMAL32 || Type == TYPE_DECIMAL64 || - Type == TYPE_DECIMAL128I) { + Type == TYPE_DECIMAL128I || Type == TYPE_DECIMAL256) { tmp = convert(*col, condition); } else { tmp = convert(condition); diff --git a/be/src/olap/key_coder.cpp b/be/src/olap/key_coder.cpp index 803b353375d2d9e..168117117d91161 100644 --- a/be/src/olap/key_coder.cpp +++ b/be/src/olap/key_coder.cpp @@ -80,6 +80,7 @@ class KeyCoderResolver { add_mapping(); add_mapping(); add_mapping(); + add_mapping(); } template diff --git a/be/src/olap/key_coder.h b/be/src/olap/key_coder.h index 30d33cd3faec009..6885a0d96f251bc 100644 --- a/be/src/olap/key_coder.h +++ b/be/src/olap/key_coder.h @@ -85,6 +85,7 @@ class KeyCoderTraits< field_type, typename std::enable_if< std::is_integral::CppType>::value || + field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL256 || vectorized::IsDecimalNumber::CppType>>::type> { public: using 
CppType = typename CppTypeTraits::CppType; @@ -93,20 +94,24 @@ class KeyCoderTraits< private: // Swap value's endian from/to big endian static UnsignedCppType swap_big_endian(UnsignedCppType val) { - switch (sizeof(UnsignedCppType)) { - case 1: - return val; - case 2: - return BigEndian::FromHost16(val); - case 4: - return BigEndian::FromHost32(val); - case 8: - return BigEndian::FromHost64(val); - case 16: - return BigEndian::FromHost128(val); - default: - LOG(FATAL) << "Invalid type to big endian, type=" << int(field_type) - << ", size=" << sizeof(UnsignedCppType); + if constexpr (field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL256) { + return BigEndian::FromHost256(val); + } else { + switch (sizeof(UnsignedCppType)) { + case 1: + return val; + case 2: + return BigEndian::FromHost16(val); + case 4: + return BigEndian::FromHost32(val); + case 8: + return BigEndian::FromHost64(val); + case 16: + return BigEndian::FromHost128(val); + default: + LOG(FATAL) << "Invalid type to big endian, type=" << int(field_type) + << ", size=" << sizeof(UnsignedCppType); + } } } diff --git a/be/src/olap/olap_common.h b/be/src/olap/olap_common.h index 130d65e7ef448da..3811aab378f6d8a 100644 --- a/be/src/olap/olap_common.h +++ b/be/src/olap/olap_common.h @@ -143,7 +143,8 @@ enum class FieldType { OLAP_FIELD_TYPE_DECIMAL128I = 33, OLAP_FIELD_TYPE_JSONB = 34, OLAP_FIELD_TYPE_VARIANT = 35, - OLAP_FIELD_TYPE_AGG_STATE = 36 + OLAP_FIELD_TYPE_AGG_STATE = 36, + OLAP_FIELD_TYPE_DECIMAL256 = 37, }; // Define all aggregation methods supported by Field @@ -197,6 +198,7 @@ constexpr bool field_is_numeric_type(const FieldType& field_type) { field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL32 || field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL64 || field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL128I || + field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL256 || field_type == FieldType::OLAP_FIELD_TYPE_BOOL; } diff --git a/be/src/olap/predicate_creator.h b/be/src/olap/predicate_creator.h index 
6298f6f231e481e..dd9fded40eef2a5 100644 --- a/be/src/olap/predicate_creator.h +++ b/be/src/olap/predicate_creator.h @@ -96,8 +96,8 @@ class DecimalPredicateCreator : public PredicateCreator { static CppType convert(const TabletColumn& column, const std::string& condition) { StringParser::ParseResult result = StringParser::ParseResult::PARSE_SUCCESS; // return CppType value cast from int128_t - return StringParser::string_to_decimal(condition.data(), condition.size(), - column.precision(), column.frac(), &result); + return CppType(StringParser::string_to_decimal( + condition.data(), condition.size(), column.precision(), column.frac(), &result)); } }; @@ -195,6 +195,9 @@ std::unique_ptr> get_creator(const FieldType& ty case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: { return std::make_unique>(); } + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: { + return std::make_unique>(); + } case FieldType::OLAP_FIELD_TYPE_CHAR: { return std::make_unique>(); } diff --git a/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp b/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp index 8523740920fec6a..227e91400239130 100644 --- a/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp +++ b/be/src/olap/rowset/segment_v2/bitmap_index_writer.cpp @@ -247,6 +247,9 @@ Status BitmapIndexWriter::create(const TypeInfo* type_info, case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: res->reset(new BitmapIndexWriterImpl(type_info)); break; + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + res->reset(new BitmapIndexWriterImpl(type_info)); + break; case FieldType::OLAP_FIELD_TYPE_BOOL: res->reset(new BitmapIndexWriterImpl(type_info)); break; diff --git a/be/src/olap/rowset/segment_v2/bitshuffle_page.h b/be/src/olap/rowset/segment_v2/bitshuffle_page.h index 05d07acf88462b6..54f446070f106f1 100644 --- a/be/src/olap/rowset/segment_v2/bitshuffle_page.h +++ b/be/src/olap/rowset/segment_v2/bitshuffle_page.h @@ -267,6 +267,7 @@ inline Status parse_bit_shuffle_header(const Slice& data, size_t& num_elements, case 
8: case 12: case 16: + case 32: break; default: return Status::InternalError("invalid size_of_elem:{}", size_of_element); diff --git a/be/src/olap/rowset/segment_v2/bloom_filter_index_writer.cpp b/be/src/olap/rowset/segment_v2/bloom_filter_index_writer.cpp index 3afde1340c62e72..e7e3e5e7f6a0623 100644 --- a/be/src/olap/rowset/segment_v2/bloom_filter_index_writer.cpp +++ b/be/src/olap/rowset/segment_v2/bloom_filter_index_writer.cpp @@ -315,6 +315,7 @@ Status BloomFilterIndexWriter::create(const BloomFilterOptions& bf_options, M(FieldType::OLAP_FIELD_TYPE_DECIMAL32) M(FieldType::OLAP_FIELD_TYPE_DECIMAL64) M(FieldType::OLAP_FIELD_TYPE_DECIMAL128I) + M(FieldType::OLAP_FIELD_TYPE_DECIMAL256) #undef M default: return Status::NotSupported("unsupported type for bitmap index: {}", diff --git a/be/src/olap/rowset/segment_v2/encoding_info.cpp b/be/src/olap/rowset/segment_v2/encoding_info.cpp index 573ea925325a986..462b5bdf51c2a44 100644 --- a/be/src/olap/rowset/segment_v2/encoding_info.cpp +++ b/be/src/olap/rowset/segment_v2/encoding_info.cpp @@ -321,6 +321,10 @@ EncodingInfoResolver::EncodingInfoResolver() { _add_map(); _add_map(); + _add_map(); + _add_map(); + _add_map(); + _add_map(); _add_map(); diff --git a/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp b/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp index a9f7daf4b451887..b6682e3ae20e2d0 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index_writer.cpp @@ -624,6 +624,12 @@ Status InvertedIndexColumnWriter::create(const Field* field, field_name, segment_file_name, dir, fs, index_meta); break; } + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: { + *res = std::make_unique< + InvertedIndexColumnWriterImpl>( + field_name, segment_file_name, dir, fs, index_meta); + break; + } case FieldType::OLAP_FIELD_TYPE_BOOL: { *res = std::make_unique>( field_name, segment_file_name, dir, fs, index_meta); diff --git 
a/be/src/olap/rowset/segment_v2/zone_map_index.cpp b/be/src/olap/rowset/segment_v2/zone_map_index.cpp index 75f0a9d845c8b07..40f755654ef699a 100644 --- a/be/src/olap/rowset/segment_v2/zone_map_index.cpp +++ b/be/src/olap/rowset/segment_v2/zone_map_index.cpp @@ -200,7 +200,8 @@ Status ZoneMapIndexReader::_load(bool use_page_cache, bool kept_in_memory, M(TYPE_STRING) \ M(TYPE_DECIMAL32) \ M(TYPE_DECIMAL64) \ - M(TYPE_DECIMAL128I) + M(TYPE_DECIMAL128I) \ + M(TYPE_DECIMAL256) Status ZoneMapIndexWriter::create(Field* field, std::unique_ptr& res) { switch (field->type()) { diff --git a/be/src/olap/schema.cpp b/be/src/olap/schema.cpp index a3297b4c2faba07..e55b1dcf2aa1c10 100644 --- a/be/src/olap/schema.cpp +++ b/be/src/olap/schema.cpp @@ -199,6 +199,9 @@ vectorized::IColumn::MutablePtr Schema::get_predicate_column_ptr(const Field& fi case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: ptr = doris::vectorized::PredicateColumnType::create(); break; + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + ptr = doris::vectorized::PredicateColumnType::create(); + break; default: LOG(FATAL) << "Unexpected type when choosing predicate column, type=" << int(field.type()); } diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp index 20260f2f4ffa681..c74e94f458805df 100644 --- a/be/src/olap/tablet_schema.cpp +++ b/be/src/olap/tablet_schema.cpp @@ -92,6 +92,8 @@ FieldType TabletColumn::get_field_type_by_string(const std::string& type_str) { type = FieldType::OLAP_FIELD_TYPE_DECIMAL64; } else if (0 == upper_type_str.compare("DECIMAL128I")) { type = FieldType::OLAP_FIELD_TYPE_DECIMAL128I; + } else if (0 == upper_type_str.compare("DECIMAL256")) { + type = FieldType::OLAP_FIELD_TYPE_DECIMAL256; } else if (0 == upper_type_str.compare(0, 7, "DECIMAL")) { type = FieldType::OLAP_FIELD_TYPE_DECIMAL; } else if (0 == upper_type_str.compare(0, 7, "VARCHAR")) { @@ -226,6 +228,9 @@ std::string TabletColumn::get_string_by_field_type(FieldType type) { case 
FieldType::OLAP_FIELD_TYPE_DECIMAL128I: return "DECIMAL128I"; + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: + return "DECIMAL256"; + case FieldType::OLAP_FIELD_TYPE_VARCHAR: return "VARCHAR"; @@ -351,6 +356,8 @@ uint32_t TabletColumn::get_field_length_by_type(TPrimitiveType::type type, uint3 return 8; case TPrimitiveType::DECIMAL128I: return 16; + case TPrimitiveType::DECIMAL256: + return 32; case TPrimitiveType::DECIMALV2: return 12; // use 12 bytes in olap engine. default: diff --git a/be/src/olap/types.cpp b/be/src/olap/types.cpp index 2c92bd3f2c81b59..b095d830e7079f2 100644 --- a/be/src/olap/types.cpp +++ b/be/src/olap/types.cpp @@ -98,6 +98,7 @@ const TypeInfo* get_scalar_type_info(FieldType field_type) { get_scalar_type_info(), get_scalar_type_info(), get_scalar_type_info(), + get_scalar_type_info(), nullptr}; return field_type_array[int(field_type)]; } diff --git a/be/src/olap/types.h b/be/src/olap/types.h index bb54959aee8b3fd..509e6dde8bfc1d9 100644 --- a/be/src/olap/types.h +++ b/be/src/olap/types.h @@ -53,6 +53,7 @@ #include "util/string_parser.hpp" #include "util/types.h" #include "vec/common/arena.h" +#include "vec/core/wide_integer.h" #include "vec/runtime/vdatetime_value.h" namespace doris { @@ -690,6 +691,11 @@ struct CppTypeTraits { using UnsignedCppType = uint128_t; }; template <> +struct CppTypeTraits { + using CppType = Int256; + using UnsignedCppType = wide::UInt256; +}; +template <> struct CppTypeTraits { using CppType = uint24_t; using UnsignedCppType = uint24_t; @@ -1083,6 +1089,32 @@ struct FieldTypeTraits } }; +template <> +struct FieldTypeTraits + : public BaseFieldtypeTraits { + static Status from_string(void* buf, const std::string& scan_key, const int precision, + const int scale) { + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + auto value = StringParser::string_to_decimal( + scan_key.c_str(), scan_key.size(), 76, scale, &result); + if (result == StringParser::PARSE_FAILURE) { + return Status::Error( + 
"FieldTypeTraits::from_string meet PARSE_FAILURE"); + } + *reinterpret_cast(buf) = value; + return Status::OK(); + } + static std::string to_string(const void* src) { + // TODO: support decimal256 + DCHECK(false); + return ""; + // auto value = reinterpret_cast(src); + // fmt::memory_buffer buffer; + // fmt::format_to(buffer, "{}", *value); + // return std::string(buffer.data(), buffer.size()); + } +}; + template <> struct FieldTypeTraits : public BaseFieldtypeTraits { diff --git a/be/src/olap/utils.h b/be/src/olap/utils.h index 75df35e32ac8604..f8e2e1fbe9936ee 100644 --- a/be/src/olap/utils.h +++ b/be/src/olap/utils.h @@ -257,6 +257,7 @@ constexpr bool is_numeric_type(const FieldType& field_type) { field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL32 || field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL64 || field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL128I || + field_type == FieldType::OLAP_FIELD_TYPE_DECIMAL256 || field_type == FieldType::OLAP_FIELD_TYPE_BOOL; } diff --git a/be/src/pipeline/exec/scan_operator.cpp b/be/src/pipeline/exec/scan_operator.cpp index df45db62ed88409..fa34f8aa4bbfda2 100644 --- a/be/src/pipeline/exec/scan_operator.cpp +++ b/be/src/pipeline/exec/scan_operator.cpp @@ -228,6 +228,7 @@ Status ScanLocalState::_normalize_conjuncts() { M(DECIMAL32) \ M(DECIMAL64) \ M(DECIMAL128I) \ + M(DECIMAL256) \ M(DECIMALV2) \ M(BOOLEAN) APPLY_FOR_PRIMITIVE_TYPE(M) @@ -891,7 +892,8 @@ Status ScanLocalState::_change_value_range(ColumnValueRange #include "util/hash_util.hpp" +#include "vec/core/wide_integer.h" namespace doris { @@ -140,6 +141,12 @@ class DecimalV2Value { // ATTN: invoker must make sure no OVERFLOW operator int128_t() const { return static_cast(_value / ONE_BILLION); } + operator wide::Int256() const { + wide::Int256 result; + wide::Int256::_impl::wide_integer_from_builtin(result, _value); + return result; + } + operator bool() const { return _value != 0; } operator int8_t() const { return static_cast(operator int64_t()); } diff --git 
a/be/src/runtime/define_primitive_type.h b/be/src/runtime/define_primitive_type.h index 0ecacb92347c0bb..44a0f2c38ed13e0 100644 --- a/be/src/runtime/define_primitive_type.h +++ b/be/src/runtime/define_primitive_type.h @@ -63,8 +63,9 @@ enum PrimitiveType : PrimitiveNative { TYPE_VARIANT, /* 32 */ TYPE_LAMBDA_FUNCTION, /* 33 */ TYPE_AGG_STATE, /* 34 */ + TYPE_DECIMAL256, /* 35 */ }; constexpr PrimitiveNative BEGIN_OF_PRIMITIVE_TYPE = INVALID_TYPE; -constexpr PrimitiveNative END_OF_PRIMITIVE_TYPE = TYPE_AGG_STATE; +constexpr PrimitiveNative END_OF_PRIMITIVE_TYPE = TYPE_DECIMAL256; } // namespace doris diff --git a/be/src/runtime/fold_constant_executor.cpp b/be/src/runtime/fold_constant_executor.cpp index 793f0209d1f9f27..3a99f0dfc251289 100644 --- a/be/src/runtime/fold_constant_executor.cpp +++ b/be/src/runtime/fold_constant_executor.cpp @@ -234,7 +234,8 @@ string FoldConstantExecutor::_get_result(void* src, size_t size, const TypeDescr } case TYPE_DECIMAL32: case TYPE_DECIMAL64: - case TYPE_DECIMAL128I: { + case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: { return column_type->to_string(*column_ptr, 0); } case TYPE_ARRAY: diff --git a/be/src/runtime/primitive_type.cpp b/be/src/runtime/primitive_type.cpp index 82a189107d93013..84d5d4b0d138504 100644 --- a/be/src/runtime/primitive_type.cpp +++ b/be/src/runtime/primitive_type.cpp @@ -125,6 +125,9 @@ PrimitiveType thrift_to_type(TPrimitiveType::type ttype) { case TPrimitiveType::DECIMAL128I: return TYPE_DECIMAL128I; + case TPrimitiveType::DECIMAL256: + return TYPE_DECIMAL256; + case TPrimitiveType::CHAR: return TYPE_CHAR; @@ -237,6 +240,9 @@ TPrimitiveType::type to_thrift(PrimitiveType ptype) { case TYPE_DECIMAL128I: return TPrimitiveType::DECIMAL128I; + case TYPE_DECIMAL256: + return TPrimitiveType::DECIMAL256; + case TYPE_CHAR: return TPrimitiveType::CHAR; @@ -339,6 +345,9 @@ std::string type_to_string(PrimitiveType t) { case TYPE_DECIMAL128I: return "DECIMAL128I"; + case TYPE_DECIMAL256: + return "DECIMAL256"; + case 
TYPE_CHAR: return "CHAR"; @@ -445,6 +454,9 @@ std::string type_to_odbc_string(PrimitiveType t) { case TYPE_DECIMAL128I: return "decimal128"; + case TYPE_DECIMAL256: + return "decimal256"; + case TYPE_CHAR: return "char"; diff --git a/be/src/runtime/primitive_type.h b/be/src/runtime/primitive_type.h index 2f331b109299be1..997e9b611b50a19 100644 --- a/be/src/runtime/primitive_type.h +++ b/be/src/runtime/primitive_type.h @@ -59,6 +59,7 @@ constexpr bool is_enumeration_type(PrimitiveType type) { case TYPE_DECIMAL32: case TYPE_DECIMAL64: case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: case TYPE_BOOLEAN: case TYPE_ARRAY: case TYPE_STRUCT: @@ -205,6 +206,11 @@ struct PrimitiveTypeTraits { using ColumnType = vectorized::ColumnDecimal; }; template <> +struct PrimitiveTypeTraits { + using CppType = vectorized::Decimal256; + using ColumnType = vectorized::ColumnDecimal; +}; +template <> struct PrimitiveTypeTraits { using CppType = __int128_t; using ColumnType = vectorized::ColumnInt128; diff --git a/be/src/runtime/raw_value.h b/be/src/runtime/raw_value.h index b1e635ea1693a1a..1990d11023d4eee 100644 --- a/be/src/runtime/raw_value.h +++ b/be/src/runtime/raw_value.h @@ -103,6 +103,8 @@ inline uint32_t RawValue::zlib_crc32(const void* v, size_t len, const PrimitiveT return HashUtil::zlib_crc_hash(v, 8, seed); case TYPE_DECIMAL128I: return HashUtil::zlib_crc_hash(v, 16, seed); + case TYPE_DECIMAL256: + return HashUtil::zlib_crc_hash(v, 32, seed); default: DCHECK(false) << "invalid type: " << type; return 0; diff --git a/be/src/runtime/runtime_predicate.cpp b/be/src/runtime/runtime_predicate.cpp index f053b842c7be179..2b949fb10e6c534 100644 --- a/be/src/runtime/runtime_predicate.cpp +++ b/be/src/runtime/runtime_predicate.cpp @@ -112,6 +112,10 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) _get_value_fn = get_decimal128_value; break; } + case PrimitiveType::TYPE_DECIMAL256: { + _get_value_fn = get_decimal256_value; + break; + } default: return 
Status::InvalidArgument("unsupported runtime predicate type {}", type); } diff --git a/be/src/runtime/runtime_predicate.h b/be/src/runtime/runtime_predicate.h index 9dd48279acbfe68..b1d4dadf1a4fa8a 100644 --- a/be/src/runtime/runtime_predicate.h +++ b/be/src/runtime/runtime_predicate.h @@ -173,6 +173,12 @@ class RuntimePredicate { auto v = field.get>(); return cast_to_string(v.get_value(), v.get_scale()); } + + static std::string get_decimal256_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits::CppType; + auto v = field.get>(); + return cast_to_string(v.get_value(), v.get_scale()); + } }; } // namespace vectorized diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index 06ceb699e7fd12b..1e9ec6a8ff86f17 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -146,6 +146,10 @@ class RuntimeState { _query_options.check_overflow_for_decimal; } + bool enable_decima256() const { + return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256; + } + bool enable_common_expr_pushdown() const { return _query_options.__isset.enable_common_expr_pushdown && _query_options.enable_common_expr_pushdown; diff --git a/be/src/runtime/type_limit.h b/be/src/runtime/type_limit.h index 4d9fd5e646a4c04..d23a9f1921656c4 100644 --- a/be/src/runtime/type_limit.h +++ b/be/src/runtime/type_limit.h @@ -20,6 +20,7 @@ #include "runtime/datetime_value.h" #include "runtime/decimalv2_value.h" #include "vec/common/string_ref.h" +#include "vec/core/wide_integer.h" namespace doris { @@ -70,6 +71,13 @@ struct type_limit { } static vectorized::Decimal128 min() { return -max(); } }; +static Int256 MAX_DECIMAL256_INT({18446744073709551615ul, 8607968719199866879ul, + 532749306367912313ul, 1593091911132452277ul}); +template <> +struct type_limit { + static vectorized::Decimal256 max() { return vectorized::Decimal256(MAX_DECIMAL256_INT); } + static vectorized::Decimal256 min() { return 
vectorized::Decimal256(-MAX_DECIMAL256_INT); } +}; template <> struct type_limit { diff --git a/be/src/runtime/types.cpp b/be/src/runtime/types.cpp index 4cb3d3ef5bb6308..10a6b47f84c6901 100644 --- a/be/src/runtime/types.cpp +++ b/be/src/runtime/types.cpp @@ -46,7 +46,8 @@ TypeDescriptor::TypeDescriptor(const std::vector& types, int* idx) DCHECK(scalar_type.__isset.len); len = scalar_type.len; } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || - type == TYPE_DECIMAL128I || type == TYPE_DATETIMEV2 || type == TYPE_TIMEV2) { + type == TYPE_DECIMAL128I || type == TYPE_DECIMAL256 || type == TYPE_DATETIMEV2 || + type == TYPE_TIMEV2) { DCHECK(scalar_type.__isset.precision); DCHECK(scalar_type.__isset.scale); precision = scalar_type.precision; @@ -151,7 +152,7 @@ void TypeDescriptor::to_thrift(TTypeDesc* thrift_type) const { // DCHECK_NE(len, -1); scalar_type.__set_len(len); } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || - type == TYPE_DECIMAL128I || type == TYPE_DATETIMEV2) { + type == TYPE_DECIMAL128I || type == TYPE_DECIMAL256 || type == TYPE_DATETIMEV2) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type.__set_precision(precision); @@ -168,7 +169,7 @@ void TypeDescriptor::to_protobuf(PTypeDesc* ptype) const { if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL || type == TYPE_STRING) { scalar_type->set_len(len); } else if (type == TYPE_DECIMALV2 || type == TYPE_DECIMAL32 || type == TYPE_DECIMAL64 || - type == TYPE_DECIMAL128I || type == TYPE_DATETIMEV2) { + type == TYPE_DECIMAL128I || type == TYPE_DECIMAL256 || type == TYPE_DATETIMEV2) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type->set_precision(precision); @@ -218,7 +219,7 @@ TypeDescriptor::TypeDescriptor(const google::protobuf::RepeatedPtrField"; return ss.str(); diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index bb030b66d6b2481..4cb7d51e4b5600d 100644 --- a/be/src/runtime/types.h 
+++ b/be/src/runtime/types.h @@ -29,6 +29,7 @@ #include #include "common/config.h" +#include "common/consts.h" #include "runtime/define_primitive_type.h" namespace doris { @@ -50,15 +51,6 @@ struct TypeDescriptor { int precision; int scale; - /// Must be kept in sync with FE's max precision/scale. - static constexpr int MAX_PRECISION = 38; - static constexpr int MAX_SCALE = MAX_PRECISION; - - /// The maximum precision representable by a 4-byte decimal (Decimal4Value) - static constexpr int MAX_DECIMAL4_PRECISION = 9; - /// The maximum precision representable by a 8-byte decimal (Decimal8Value) - static constexpr int MAX_DECIMAL8_PRECISION = 18; - std::vector children; bool result_is_nullable = false; @@ -118,8 +110,8 @@ struct TypeDescriptor { } static TypeDescriptor create_decimalv2_type(int precision, int scale) { - DCHECK_LE(precision, MAX_PRECISION); - DCHECK_LE(scale, MAX_SCALE); + DCHECK_LE(precision, BeConsts::MAX_DECIMALV2_PRECISION); + DCHECK_LE(scale, BeConsts::MAX_DECIMALV2_SCALE); DCHECK_GE(precision, 0); DCHECK_LE(scale, precision); TypeDescriptor ret; @@ -130,17 +122,19 @@ struct TypeDescriptor { } static TypeDescriptor create_decimalv3_type(int precision, int scale) { - DCHECK_LE(precision, MAX_PRECISION); - DCHECK_LE(scale, MAX_SCALE); + DCHECK_LE(precision, BeConsts::MAX_DECIMALV3_PRECISION); + DCHECK_LE(scale, BeConsts::MAX_DECIMALV3_SCALE); DCHECK_GE(precision, 0); DCHECK_LE(scale, precision); TypeDescriptor ret; - if (precision <= MAX_DECIMAL4_PRECISION) { + if (precision <= BeConsts::MAX_DECIMAL32_PRECISION) { ret.type = TYPE_DECIMAL32; - } else if (precision <= MAX_DECIMAL8_PRECISION) { + } else if (precision <= BeConsts::MAX_DECIMAL64_PRECISION) { ret.type = TYPE_DECIMAL64; - } else { + } else if (precision <= BeConsts::MAX_DECIMAL128_PRECISION) { ret.type = TYPE_DECIMAL128I; + } else { + ret.type = TYPE_DECIMAL256; } ret.precision = precision; ret.scale = scale; @@ -216,7 +210,8 @@ struct TypeDescriptor { bool is_decimal_v2_type() const { 
return type == TYPE_DECIMALV2; } bool is_decimal_v3_type() const { - return (type == TYPE_DECIMAL32) || (type == TYPE_DECIMAL64) || (type == TYPE_DECIMAL128I); + return (type == TYPE_DECIMAL32) || (type == TYPE_DECIMAL64) || (type == TYPE_DECIMAL128I) || + (type == TYPE_DECIMAL256); } bool is_datetime_type() const { return type == TYPE_DATETIME; } @@ -244,13 +239,16 @@ struct TypeDescriptor { static inline int get_decimal_byte_size(int precision) { DCHECK_GT(precision, 0); - if (precision <= MAX_DECIMAL4_PRECISION) { + if (precision <= BeConsts::MAX_DECIMAL32_PRECISION) { return 4; } - if (precision <= MAX_DECIMAL8_PRECISION) { + if (precision <= BeConsts::MAX_DECIMAL64_PRECISION) { return 8; } - return 16; + if (precision <= BeConsts::MAX_DECIMAL128_PRECISION) { + return 16; + } + return 32; } std::string debug_string() const; diff --git a/be/src/util/binary_cast.hpp b/be/src/util/binary_cast.hpp index ecba899ec220ee8..43ea6486bd476c1 100644 --- a/be/src/util/binary_cast.hpp +++ b/be/src/util/binary_cast.hpp @@ -24,6 +24,7 @@ #include "runtime/datetime_value.h" #include "runtime/decimalv2_value.h" #include "util/types.h" +#include "vec/core/wide_integer.h" #include "vec/runtime/vdatetime_value.h" namespace doris { union TypeConverter { @@ -79,6 +80,7 @@ To binary_cast(From from) { match_v; constexpr bool from_i128_to_decv2 = match_v; constexpr bool from_decv2_to_i128 = match_v; + constexpr bool from_decv2_to_i256 = match_v; constexpr bool from_ui32_to_date_v2 = match_v static T get_scale_multiplier(int scale) { static_assert(std::is_same_v || std::is_same_v || - std::is_same_v, + std::is_same_v || std::is_same_v, "You can only instantiate as int32_t, int64_t, __int128."); if constexpr (std::is_same_v) { return common::exp10_i32(scale); @@ -99,6 +100,8 @@ class StringParser { return common::exp10_i64(scale); } else if constexpr (std::is_same_v) { return common::exp10_i128(scale); + } else if constexpr (std::is_same_v) { + return common::exp10_i256(scale); } } diff 
--git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 2e58ccb56ea4767..8895e39908d64a0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -35,7 +35,18 @@ struct Avg { template using AggregateFuncAvg = typename Avg::Function; +template +struct AvgDecimal256 { + using FieldType = typename AvgNearestFieldTypeTrait::Type; + using Function = AggregateFunctionAvg>; +}; + +template +using AggregateFuncAvgDecimal256 = typename Avg::Function; + void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("avg", creator_with_type::creator); + factory.register_function_both("avg_decimal256", + creator_with_type::creator); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index 9697658ec70ae11..7d388edf89a104c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -58,6 +58,7 @@ namespace doris::vectorized { template struct AggregateFunctionAvgData { + using ResultType = T; T sum {}; UInt64 count = 0; @@ -87,7 +88,11 @@ struct AggregateFunctionAvgData { Decimal128 ret(cal_ret.value()); return ret; } else { - return static_cast(sum) / count; + if constexpr (IsDecimal256) { + return static_cast(sum / T(count)); + } else { + return static_cast(sum) / count; + } } } @@ -107,16 +112,25 @@ template class AggregateFunctionAvg final : public IAggregateFunctionDataHelper> { public: + /* using ResultType = DisposeDecimal; using ResultDataType = std::conditional_t, DataTypeDecimal, std::conditional_t, DataTypeDecimal, DataTypeNumber>>; + */ + using ResultType = std::conditional_t< + IsDecimalV2, Decimal128, + std::conditional_t, typename Data::ResultType, Float64>>; + using ResultDataType 
= std::conditional_t< + IsDecimalV2, DataTypeDecimal, + std::conditional_t, DataTypeDecimal, + DataTypeNumber>>; using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; - using ColVecResult = - std::conditional_t, ColumnDecimal, - std::conditional_t, ColumnDecimal, - ColumnVector>>; + using ColVecResult = std::conditional_t< + IsDecimalV2, ColumnDecimal, + std::conditional_t, ColumnDecimal, + ColumnVector>>; /// ctor for native types AggregateFunctionAvg(const DataTypes& argument_types_) diff --git a/be/src/vec/aggregate_functions/aggregate_function_product.h b/be/src/vec/aggregate_functions/aggregate_function_product.h index ca1c7c9fae9e156..8a13ad0b0340271 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_product.h +++ b/be/src/vec/aggregate_functions/aggregate_function_product.h @@ -134,7 +134,7 @@ class AggregateFunctionProduct final void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); - this->data(place).add(column.get_data()[row_num], multiplier); + this->data(place).add(TResult(column.get_data()[row_num]), multiplier); } void reset(AggregateDataPtr place) const override { diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 618340dd880acf8..dccbd9a4d575fa4 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -80,7 +80,8 @@ class AggregateFunctionSimpleFactory { AggregateFunctionPtr get(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable = false, - int be_version = BeExecVersionManager::get_newest_version()) { + int be_version = BeExecVersionManager::get_newest_version(), + bool enable_decima256 = false) { bool nullable = false; for (const auto& type : argument_types) { if (type->is_nullable()) 
{ @@ -89,6 +90,11 @@ class AggregateFunctionSimpleFactory { } std::string name_str = name; + if (enable_decima256) { + if (name_str == "sum" || name_str == "avg") { + name_str += "_decimal256"; + } + } temporary_function_update(be_version, name_str); if (function_alias.count(name)) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp index ede242519849b71..3ee7dc6ff483331 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp @@ -27,6 +27,8 @@ namespace doris::vectorized { void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("sum", creator_with_type::creator); + factory.register_function_both( + "sum_decimal256", creator_with_type::creator); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h b/be/src/vec/aggregate_functions/aggregate_function_sum.h index 9f58023d507ff3b..41677dd419bf2fb 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h @@ -101,7 +101,7 @@ class AggregateFunctionSum final void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { const auto& column = assert_cast(*columns[0]); - this->data(place).add(column.get_data()[row_num]); + this->data(place).add(TResult(column.get_data()[row_num])); } void reset(AggregateDataPtr place) const override { this->data(place).sum = {}; } @@ -156,7 +156,7 @@ class AggregateFunctionSum final auto* dst_data = col.get_data().data(); for (size_t i = 0; i != num_rows; ++i) { auto& state = *reinterpret_cast(&dst_data[sizeof(Data) * i]); - state.sum = src_data[i]; + state.sum = TResult(src_data[i]); } } @@ -231,6 +231,18 @@ struct SumSimple { template using AggregateFunctionSumSimple = typename SumSimple::Function; +const static 
std::string DECIMAL256_SUFFIX {"_decimal256"}; +template +struct SumSimpleDecimal256 { + /// @note It uses slow Decimal128 (cause we need such a variant). sumWithOverflow is faster for Decimal32/64 + using ResultType = std::conditional_t>, T>; + using AggregateDataType = AggregateFunctionSumData; + using Function = AggregateFunctionSum; +}; + +template +using AggregateFunctionSumSimpleDecimal256 = typename SumSimpleDecimal256::Function; + // do not level up return type for agg reader template using AggregateFunctionSumSimpleReader = typename SumSimple::Function; diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h index f50524085cb46eb..58ddd455bc850cf 100644 --- a/be/src/vec/aggregate_functions/helpers.h +++ b/be/src/vec/aggregate_functions/helpers.h @@ -45,7 +45,8 @@ M(Decimal32) \ M(Decimal64) \ M(Decimal128) \ - M(Decimal128I) + M(Decimal128I) \ + M(Decimal256) /** If the serialized type is not the default type(string), * aggregation function need to override these functions: diff --git a/be/src/vec/columns/column_array.h b/be/src/vec/columns/column_array.h index 668abd1ef62f1f2..6ece9a4bbe4dd8f 100644 --- a/be/src/vec/columns/column_array.h +++ b/be/src/vec/columns/column_array.h @@ -52,9 +52,10 @@ class Arena; } // namespace doris //TODO: use marcos below to decouple array function calls -#define ALL_COLUMNS_NUMBER \ - ColumnUInt8, ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnInt128, ColumnFloat32, \ - ColumnFloat64, ColumnDecimal32, ColumnDecimal64, ColumnDecimal128I, ColumnDecimal128 +#define ALL_COLUMNS_NUMBER \ + ColumnUInt8, ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, ColumnInt128, ColumnFloat32, \ + ColumnFloat64, ColumnDecimal32, ColumnDecimal64, ColumnDecimal128I, ColumnDecimal128, \ + ColumnDecimal256 #define ALL_COLUMNS_TIME ColumnDate, ColumnDateTime, ColumnDateV2, ColumnDateTimeV2 #define ALL_COLUMNS_NUMERIC ALL_COLUMNS_NUMBER, ALL_COLUMNS_TIME #define ALL_COLUMNS_SIMPLE 
ALL_COLUMNS_NUMERIC, ColumnString diff --git a/be/src/vec/columns/column_decimal.cpp b/be/src/vec/columns/column_decimal.cpp index edc8a5777fff868..c4d94d91a4e8a20 100644 --- a/be/src/vec/columns/column_decimal.cpp +++ b/be/src/vec/columns/column_decimal.cpp @@ -121,8 +121,9 @@ template UInt64 ColumnDecimal::get64(size_t n) const { if constexpr (sizeof(T) > sizeof(UInt64)) { LOG(FATAL) << "Method get64 is not supported for " << get_family_name(); + } else { + return static_cast(data[n]); } - return static_cast(data[n]); } template @@ -501,6 +502,13 @@ Decimal128I ColumnDecimal::get_scale_multiplier() const { return common::exp10_i128(scale); } +// duplicate with +// Decimal256 DataTypeDecimal::get_scale_multiplier(UInt32 scale) { +template <> +Decimal256 ColumnDecimal::get_scale_multiplier() const { + return Decimal256(common::exp10_i256(scale)); +} + template ColumnPtr ColumnDecimal::index(const IColumn& indexes, size_t limit) const { return select_index_impl(*this, indexes, limit); @@ -510,4 +518,5 @@ template class ColumnDecimal; template class ColumnDecimal; template class ColumnDecimal; template class ColumnDecimal; +template class ColumnDecimal; } // namespace doris::vectorized diff --git a/be/src/vec/columns/column_decimal.h b/be/src/vec/columns/column_decimal.h index 85ce339608d96b5..b0ca5250bd51caf 100644 --- a/be/src/vec/columns/column_decimal.h +++ b/be/src/vec/columns/column_decimal.h @@ -200,7 +200,7 @@ class ColumnDecimal final : public COWHelper; using ColumnDecimal64 = ColumnDecimal; using ColumnDecimal128 = ColumnDecimal; using ColumnDecimal128I = ColumnDecimal; +using ColumnDecimal256 = ColumnDecimal; template struct IsFixLenColumnType { diff --git a/be/src/vec/common/arithmetic_overflow.h b/be/src/vec/common/arithmetic_overflow.h index 0d0828a61bfa875..b4b55eb47a510f4 100644 --- a/be/src/vec/common/arithmetic_overflow.h +++ b/be/src/vec/common/arithmetic_overflow.h @@ -20,6 +20,7 @@ #pragma once +#include "vec/core/wide_integer.h" namespace 
common { template inline bool add_overflow(T x, T y, T& res) { @@ -50,6 +51,13 @@ inline bool add_overflow(__int128 x, __int128 y, __int128& res) { return (y > 0 && x > max_int128 - y) || (y < 0 && x < min_int128 - y); } +template <> +inline bool add_overflow(wide::Int256 x, wide::Int256 y, wide::Int256& res) { + static constexpr wide::Int256 min_int256 = std::numeric_limits::min(); + static constexpr wide::Int256 max_int256 = std::numeric_limits::max(); + res = x + y; + return (y > 0 && x > max_int256 - y) || (y < 0 && x < min_int256 - y); +} template inline bool sub_overflow(T x, T y, T& res) { return __builtin_sub_overflow(x, y, &res); @@ -79,6 +87,14 @@ inline bool sub_overflow(__int128 x, __int128 y, __int128& res) { return (y < 0 && x > max_int128 + y) || (y > 0 && x < min_int128 + y); } +template <> +inline bool sub_overflow(wide::Int256 x, wide::Int256 y, wide::Int256& res) { + static constexpr wide::Int256 min_int256 = std::numeric_limits::min(); + static constexpr wide::Int256 max_int256 = std::numeric_limits::max(); + res = x - y; + return (y < 0 && x > max_int256 + y) || (y > 0 && x < min_int256 + y); +} + template inline bool mul_overflow(T x, T y, T& res) { return __builtin_mul_overflow(x, y, &res); @@ -109,4 +125,13 @@ inline bool mul_overflow(__int128 x, __int128 y, __int128& res) { unsigned __int128 b = (y > 0) ? y : -y; return (a * b) / b != a; } + +template <> +inline bool mul_overflow(wide::Int256 x, wide::Int256 y, wide::Int256& res) { + res = x * y; + if (!x || !y) return false; + wide::UInt256 a = (x > 0) ? x : -x; + wide::UInt256 b = (y > 0) ? 
y : -y; + return (a * b) / b != a; +} } // namespace common diff --git a/be/src/vec/common/field_visitors.h b/be/src/vec/common/field_visitors.h index 68a85170d4519ea..8434483b7721eaf 100644 --- a/be/src/vec/common/field_visitors.h +++ b/be/src/vec/common/field_visitors.h @@ -63,6 +63,8 @@ typename std::decay_t::ResultType apply_visitor(Visitor&& visitor, F&& return visitor(field.template get>()); case Field::Types::Decimal128I: return visitor(field.template get>()); + case Field::Types::Decimal256: + return visitor(field.template get>()); default: LOG(FATAL) << "Bad type of Field"; return {}; diff --git a/be/src/vec/common/hash_table/hash.h b/be/src/vec/common/hash_table/hash.h index 3c7df75b0a55147..9556bf87a0718b4 100644 --- a/be/src/vec/common/hash_table/hash.h +++ b/be/src/vec/common/hash_table/hash.h @@ -26,6 +26,7 @@ #include "vec/common/string_ref.h" #include "vec/common/uint128.h" #include "vec/core/types.h" +#include "vec/core/wide_integer.h" // Here is an empirical value. static constexpr size_t HASH_MAP_PREFETCH_DIST = 16; @@ -94,6 +95,9 @@ struct DefaultHash { template <> struct DefaultHash : public doris::StringRefHash {}; +template <> +struct DefaultHash : public std::hash {}; + template struct HashCRC32; @@ -163,6 +167,23 @@ struct HashCRC32 { } }; +template <> +struct HashCRC32 { + size_t operator()(const wide::Int256& x) const { +#if defined(__SSE4_2__) || defined(__aarch64__) + doris::vectorized::UInt64 crc = -1ULL; + crc = _mm_crc32_u64(crc, x.items[0]); + crc = _mm_crc32_u64(crc, x.items[1]); + crc = _mm_crc32_u64(crc, x.items[2]); + crc = _mm_crc32_u64(crc, x.items[3]); + return crc; +#else + return Hash128to64( + {Hash128to64({x.items[0], x.items[1]}), Hash128to64({x.items[2], x.items[3]})}); +#endif + } +}; + template <> struct HashCRC32 { size_t operator()(const doris::vectorized::UInt136& x) const { diff --git a/be/src/vec/common/int_exp.h b/be/src/vec/common/int_exp.h index cac7f24f0404e4a..81ca11bb11a10c8 100644 --- 
a/be/src/vec/common/int_exp.h +++ b/be/src/vec/common/int_exp.h @@ -24,6 +24,8 @@ #include #include +#include "vec/core/wide_integer.h" + namespace exp_details { // compile-time exp(v, n) by linear recursion @@ -78,4 +80,94 @@ inline constexpr __int128 exp10_i128(int x) { return exp_details::get_exp<__int128, 10, 39>(x); } +using wide::Int256; +inline Int256 exp10_i256(int x) { + if (x < 0) return 0; + if (x > 76) return std::numeric_limits::max(); + + using Int256 = Int256; + static constexpr Int256 i10e18 {1000000000000000000ll}; + static const Int256 values[] = { + static_cast(1ll), + static_cast(10ll), + static_cast(100ll), + static_cast(1000ll), + static_cast(10000ll), + static_cast(100000ll), + static_cast(1000000ll), + static_cast(10000000ll), + static_cast(100000000ll), + static_cast(1000000000ll), + static_cast(10000000000ll), + static_cast(100000000000ll), + static_cast(1000000000000ll), + static_cast(10000000000000ll), + static_cast(100000000000000ll), + static_cast(1000000000000000ll), + static_cast(10000000000000000ll), + static_cast(100000000000000000ll), + i10e18, + i10e18 * 10ll, + i10e18 * 100ll, + i10e18 * 1000ll, + i10e18 * 10000ll, + i10e18 * 100000ll, + i10e18 * 1000000ll, + i10e18 * 10000000ll, + i10e18 * 100000000ll, + i10e18 * 1000000000ll, + i10e18 * 10000000000ll, + i10e18 * 100000000000ll, + i10e18 * 1000000000000ll, + i10e18 * 10000000000000ll, + i10e18 * 100000000000000ll, + i10e18 * 1000000000000000ll, + i10e18 * 10000000000000000ll, + i10e18 * 100000000000000000ll, + i10e18 * 100000000000000000ll * 10ll, + i10e18 * 100000000000000000ll * 100ll, + i10e18 * 100000000000000000ll * 1000ll, + i10e18 * 100000000000000000ll * 10000ll, + i10e18 * 100000000000000000ll * 100000ll, + i10e18 * 100000000000000000ll * 1000000ll, + i10e18 * 100000000000000000ll * 10000000ll, + i10e18 * 100000000000000000ll * 100000000ll, + i10e18 * 100000000000000000ll * 1000000000ll, + i10e18 * 100000000000000000ll * 10000000000ll, + i10e18 * 100000000000000000ll * 
100000000000ll, + i10e18 * 100000000000000000ll * 1000000000000ll, + i10e18 * 100000000000000000ll * 10000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000ll, + i10e18 * 100000000000000000ll * 1000000000000000ll, + i10e18 * 100000000000000000ll * 10000000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 1000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 10ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 100ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 1000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 10000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 100000ll, + i10e18 * 
100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 1000000ll, + i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * + 10000000ll, + }; + return values[x]; +} + } // namespace common diff --git a/be/src/vec/core/accurate_comparison.h b/be/src/vec/core/accurate_comparison.h index 10e961311ca9c30..6e32a0d33bd4749 100644 --- a/be/src/vec/core/accurate_comparison.h +++ b/be/src/vec/core/accurate_comparison.h @@ -28,6 +28,8 @@ #include "vec/common/nan_utils.h" #include "vec/common/string_ref.h" #include "vec/common/uint128.h" +#include "vec/core/decomposed_float.h" +#include "vec/core/extended_types.h" #include "vec/core/types.h" #include "vec/runtime/vdatetime_value.h" /** Perceptually-correct number comparisons. @@ -161,6 +163,7 @@ inline bool_if_double_can_be_used equalsOpTmpl(TAFloat a, TAInt /* Final realizations */ +/* template inline bool_if_not_safe_conversion greaterOp(A a, B b) { return greaterOpTmpl(a, b); @@ -464,6 +467,114 @@ template inline bool_if_safe_conversion greaterOrEqualsOp(A a, B b) { return a >= b; } +*/ + +template +bool lessOp(A a, B b) { + if constexpr (std::is_same_v) return a < b; + + /// float vs float + if constexpr (std::is_floating_point_v && std::is_floating_point_v) return a < b; + + /// anything vs NaN + if (is_nan(a) || is_nan(b)) return false; + + /// int vs int + if constexpr (is_integer && is_integer) { + /// same signedness + if constexpr (is_signed_v == is_signed_v) return a < b; + + /// different signedness + + if constexpr (is_signed_v && !is_signed_v) + return a < 0 || static_cast>(a) < b; + + if constexpr (!is_signed_v && is_signed_v) + return b >= 0 && a < static_cast>(b); + } + + /// int vs float + if constexpr (is_integer && std::is_floating_point_v) { + if constexpr (sizeof(A) <= 4) return static_cast(a) < static_cast(b); + + return DecomposedFloat(b).greater(a); + } + + if constexpr (std::is_floating_point_v && is_integer) { + if constexpr (sizeof(B) <= 4) return static_cast(a) 
< static_cast(b); + + return DecomposedFloat(a).less(b); + } + + static_assert(is_integer || std::is_floating_point_v); + static_assert(is_integer || std::is_floating_point_v); + __builtin_unreachable(); +} + +template +bool greaterOp(A a, B b) { + return lessOp(b, a); +} + +template +bool greaterOrEqualsOp(A a, B b) { + if (is_nan(a) || is_nan(b)) return false; + + return !lessOp(a, b); +} + +template +bool lessOrEqualsOp(A a, B b) { + if (is_nan(a) || is_nan(b)) return false; + + return !lessOp(b, a); +} + +template +bool equalsOp(A a, B b) { + if constexpr (std::is_same_v) return a == b; + + /// float vs float + if constexpr (std::is_floating_point_v && std::is_floating_point_v) return a == b; + + /// anything vs NaN + if (is_nan(a) || is_nan(b)) return false; + + /// int vs int + if constexpr (is_integer && is_integer) { + /// same signedness + if constexpr (is_signed_v == is_signed_v) return a == b; + + /// different signedness + + if constexpr (is_signed_v && !is_signed_v) + return a >= 0 && static_cast>(a) == b; + + if constexpr (!is_signed_v && is_signed_v) + return b >= 0 && a == static_cast>(b); + } + + /// int vs float + if constexpr (is_integer && std::is_floating_point_v) { + if constexpr (sizeof(A) <= 4) return static_cast(a) == static_cast(b); + + return DecomposedFloat(b).equals(a); + } + + if constexpr (std::is_floating_point_v && is_integer) { + if constexpr (sizeof(B) <= 4) return static_cast(a) == static_cast(b); + + return DecomposedFloat(a).equals(b); + } + + /// e.g comparing UUID with integer. + return false; +} + +template +bool notEqualsOp(A a, B b) { + return !equalsOp(a, b); +} /// Converts numeric to an equal numeric of other type. 
/// When `strict` is `true` check that result exactly same as input, otherwise just check overflow diff --git a/be/src/vec/core/call_on_type_index.h b/be/src/vec/core/call_on_type_index.h index 283f7aeb078c640..ecc595f88ad3592 100644 --- a/be/src/vec/core/call_on_type_index.h +++ b/be/src/vec/core/call_on_type_index.h @@ -72,6 +72,8 @@ bool call_on_basic_type(TypeIndex number, F&& f) { return f(TypePair()); case TypeIndex::Decimal128I: return f(TypePair()); + case TypeIndex::Decimal256: + return f(TypePair()); default: break; } @@ -143,6 +145,9 @@ bool call_on_basic_types(TypeIndex type_num1, TypeIndex type_num2, F&& f) { case TypeIndex::Decimal128I: return call_on_basic_type( type_num2, std::forward(f)); + case TypeIndex::Decimal256: + return call_on_basic_type( + type_num2, std::forward(f)); default: break; } @@ -215,6 +220,8 @@ bool call_on_index_and_data_type(TypeIndex number, F&& f) { return f(TypePair, T>()); case TypeIndex::Decimal128I: return f(TypePair, T>()); + case TypeIndex::Decimal256: + return f(TypePair, T>()); case TypeIndex::Date: return f(TypePair()); @@ -270,6 +277,8 @@ bool call_on_index_and_number_data_type(TypeIndex number, F&& f) { return f(TypePair, T>()); case TypeIndex::Decimal128I: return f(TypePair, T>()); + case TypeIndex::Decimal256: + return f(TypePair, T>()); default: break; } diff --git a/be/src/vec/core/decimal_comparison.h b/be/src/vec/core/decimal_comparison.h index 68a083dc159e3e2..3484f546ae108d4 100644 --- a/be/src/vec/core/decimal_comparison.h +++ b/be/src/vec/core/decimal_comparison.h @@ -27,6 +27,7 @@ #include "vec/core/accurate_comparison.h" #include "vec/core/block.h" #include "vec/core/call_on_type_index.h" +#include "vec/core/types.h" #include "vec/data_types/data_type_decimal.h" #include "vec/functions/function_helpers.h" /// todo core should not depend on function" @@ -53,6 +54,10 @@ template <> struct ConstructDecInt<16> { using Type = Int128; }; +template <> +struct ConstructDecInt<32> { + using Type = Int256; +}; 
template struct DecCompareInt { @@ -99,18 +104,22 @@ class DecimalComparison { } static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b) { - static const UInt32 max_scale = max_decimal_precision(); + static const UInt32 max_scale = max_decimal_precision(); if (scale_a > max_scale || scale_b > max_scale) { LOG(FATAL) << "Bad scale of decimal field"; } Shift shift; - if (scale_a < scale_b) + if (scale_a < scale_b) { shift.a = DataTypeDecimal(max_decimal_precision(), scale_b) - .get_scale_multiplier(scale_b - scale_a); - if (scale_a > scale_b) + .get_scale_multiplier(scale_b - scale_a) + .value; + } + if (scale_a > scale_b) { shift.b = DataTypeDecimal(max_decimal_precision(), scale_a) - .get_scale_multiplier(scale_a - scale_b); + .get_scale_multiplier(scale_a - scale_b) + .value; + } return apply_with_scale(a, b, shift); } @@ -145,12 +154,12 @@ class DecimalComparison { using Type = std::conditional_t= sizeof(U), T, U>; auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false); const DataTypeDecimal* result_type = check_decimal(*type_ptr); - shift.a = result_type->scale_factor_for(*decimal0, false); - shift.b = result_type->scale_factor_for(*decimal1, false); + shift.a = result_type->scale_factor_for(*decimal0, false).value; + shift.b = result_type->scale_factor_for(*decimal1, false).value; } else if (decimal0) { - shift.b = decimal0->get_scale_multiplier(); + shift.b = decimal0->get_scale_multiplier().value; } else if (decimal1) { - shift.a = decimal1->get_scale_multiplier(); + shift.a = decimal1->get_scale_multiplier().value; } return shift; @@ -161,7 +170,7 @@ class DecimalComparison { static Shift getScales(const DataTypePtr& left_type, const DataTypePtr&) { Shift shift; const DataTypeDecimal* decimal0 = check_decimal(*left_type); - if (decimal0) shift.b = decimal0->get_scale_multiplier(); + if (decimal0) shift.b = decimal0->get_scale_multiplier().value; return shift; } @@ -170,7 +179,7 @@ class DecimalComparison { static Shift 
getScales(const DataTypePtr&, const DataTypePtr& right_type) { Shift shift; const DataTypeDecimal* decimal1 = check_decimal(*right_type); - if (decimal1) shift.a = decimal1->get_scale_multiplier(); + if (decimal1) shift.a = decimal1->get_scale_multiplier().value; return shift; } diff --git a/be/src/vec/core/decomposed_float.h b/be/src/vec/core/decomposed_float.h new file mode 100644 index 000000000000000..5000748600365ab --- /dev/null +++ b/be/src/vec/core/decomposed_float.h @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/base/base/DecomposedFloat.h +// and modified by Doris +#pragma once + +#include +#include +#include + +#include "extended_types.h" + +/// Allows to check the internals of IEEE-754 floating point number. 
+ +template +struct FloatTraits; + +template <> +struct FloatTraits { + using UInt = uint32_t; + static constexpr size_t bits = 32; + static constexpr size_t exponent_bits = 8; + static constexpr size_t mantissa_bits = bits - exponent_bits - 1; +}; + +template <> +struct FloatTraits { + using UInt = uint64_t; + static constexpr size_t bits = 64; + static constexpr size_t exponent_bits = 11; + static constexpr size_t mantissa_bits = bits - exponent_bits - 1; +}; + +/// x = sign * (2 ^ normalized_exponent) * (1 + mantissa * 2 ^ -mantissa_bits) +/// x = sign * (2 ^ normalized_exponent + mantissa * 2 ^ (normalized_exponent - mantissa_bits)) +template +struct DecomposedFloat { + using Traits = FloatTraits; + + explicit DecomposedFloat(T x) { memcpy(&x_uint, &x, sizeof(x)); } + + typename Traits::UInt x_uint; + + bool isNegative() const { return x_uint >> (Traits::bits - 1); } + + /// Returns 0 for both +0. and -0. + int sign() const { return (exponent() == 0 && mantissa() == 0) ? 0 : (isNegative() ? -1 : 1); } + + uint16_t exponent() const { + return (x_uint >> (Traits::mantissa_bits)) & + (((1ull << (Traits::exponent_bits + 1)) - 1) >> 1); + } + + int16_t normalizedExponent() const { + return int16_t(exponent()) - ((1ull << (Traits::exponent_bits - 1)) - 1); + } + + uint64_t mantissa() const { return x_uint & ((1ull << Traits::mantissa_bits) - 1); } + + int64_t mantissaWithSign() const { return isNegative() ? -mantissa() : mantissa(); } + + /// NOTE Probably floating point instructions can be better. 
+ bool isIntegerInRepresentableRange() const { + return x_uint == 0 || + (normalizedExponent() >= 0 /// The number is not less than one + /// The number is inside the range where every integer has exact representation in float + && normalizedExponent() <= static_cast(Traits::mantissa_bits) + /// After multiplying by 2^exp, the fractional part becomes zero, means the number is integer + && ((mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == + 0)); + } + + /// Compare float with integer of arbitrary width (both signed and unsigned are supported). Assuming two's complement arithmetic. + /// This function is generic, big integers (128, 256 bit) are supported as well. + /// Infinities are compared correctly. NaNs are treated similarly to infinities, so they can be less than all numbers. + /// (note that we need total order) + /// Returns -1, 0 or 1. + template + int compare(Int rhs) const { + if (rhs == 0) return sign(); + + /// Different signs + if (isNegative() && rhs > 0) return -1; + if (!isNegative() && rhs < 0) return 1; + + /// Fractional number with magnitude less than one + if (normalizedExponent() < 0) { + if (!isNegative()) + return rhs > 0 ? -1 : 1; + else + return rhs >= 0 ? -1 : 1; + } + + /// The case of the most negative integer + if constexpr (is_signed_v) { + if (rhs == std::numeric_limits::lowest()) { + assert(isNegative()); + + if (normalizedExponent() < static_cast(8 * sizeof(Int) - is_signed_v)) + return 1; + if (normalizedExponent() > static_cast(8 * sizeof(Int) - is_signed_v)) + return -1; + + if (mantissa() == 0) + return 0; + else + return -1; + } + } + + /// Too large number: abs(float) > abs(rhs). Also the case with infinities and NaN. + if (normalizedExponent() >= static_cast(8 * sizeof(Int) - is_signed_v)) + return isNegative() ? -1 : 1; + + using UInt = std::conditional_t<(sizeof(Int) > sizeof(typename Traits::UInt)), + make_unsigned_t, typename Traits::UInt>; + UInt uint_rhs = rhs < 0 ?
-rhs : rhs; + + /// Smaller octave: abs(rhs) < abs(float) + /// FYI, TIL: octave is also called "binade", https://en.wikipedia.org/wiki/Binade + if (uint_rhs < (static_cast(1) << normalizedExponent())) return isNegative() ? -1 : 1; + + /// Larger octave: abs(rhs) > abs(float) + if (normalizedExponent() + 1 < static_cast(8 * sizeof(Int) - is_signed_v) && + uint_rhs >= (static_cast(1) << (normalizedExponent() + 1))) + return isNegative() ? 1 : -1; + + /// The same octave + /// uint_rhs == 2 ^ normalizedExponent + mantissa * 2 ^ (normalizedExponent - mantissa_bits) + + bool large_and_always_integer = + normalizedExponent() >= static_cast(Traits::mantissa_bits); + + UInt a = large_and_always_integer + ? static_cast(mantissa()) + << (normalizedExponent() - Traits::mantissa_bits) + : static_cast(mantissa()) >> + (Traits::mantissa_bits - normalizedExponent()); + + UInt b = uint_rhs - (static_cast(1) << normalizedExponent()); + + if (a < b) return isNegative() ? 1 : -1; + if (a > b) return isNegative() ? -1 : 1; + + /// Float has no fractional part means that the numbers are equal. + if (large_and_always_integer || + (mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0) + return 0; + else + /// Float has fractional part means its abs value is larger. + return isNegative() ? 
-1 : 1; + } + + template + bool equals(Int rhs) const { + return compare(rhs) == 0; + } + + template + bool notEquals(Int rhs) const { + return compare(rhs) != 0; + } + + template + bool less(Int rhs) const { + return compare(rhs) < 0; + } + + template + bool greater(Int rhs) const { + return compare(rhs) > 0; + } + + template + bool lessOrEquals(Int rhs) const { + return compare(rhs) <= 0; + } + + template + bool greaterOrEquals(Int rhs) const { + return compare(rhs) >= 0; + } +}; + +using DecomposedFloat64 = DecomposedFloat; +using DecomposedFloat32 = DecomposedFloat; diff --git a/be/src/vec/core/extended_types.h b/be/src/vec/core/extended_types.h new file mode 100644 index 000000000000000..3c67ebb50e29425 --- /dev/null +++ b/be/src/vec/core/extended_types.h @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/base/base/extended_types.h +// and modified by Doris +#pragma once + +#include + +#include "wide_integer.h" + +// using Int128 = wide::integer<128, signed>; +// using UInt128 = wide::integer<128, unsigned>; +using Int256 = wide::integer<256, signed>; +// using UInt256 = wide::integer<256, unsigned>; + +static_assert(sizeof(Int256) == 32); +// static_assert(sizeof(UInt256) == 32); + +/// The standard library type traits, such as std::is_arithmetic, with one exception +/// (std::common_type), are "set in stone". Attempting to specialize them causes undefined behavior. +/// So instead of using the std type_traits, we use our own version which allows extension. +template +struct is_signed // NOLINT(readability-identifier-naming) +{ + static constexpr bool value = std::is_signed_v; +}; + +// template <> struct is_signed { static constexpr bool value = true; }; +template <> +struct is_signed { + static constexpr bool value = true; +}; + +template +inline constexpr bool is_signed_v = is_signed::value; + +template +struct is_unsigned // NOLINT(readability-identifier-naming) +{ + static constexpr bool value = std::is_unsigned_v; +}; + +// template <> struct is_unsigned { static constexpr bool value = true; }; +// template <> struct is_unsigned { static constexpr bool value = true; }; + +template +inline constexpr bool is_unsigned_v = is_unsigned::value; + +template +concept is_integer = std::is_integral_v + // || std::is_same_v + // || std::is_same_v + || std::is_same_v; +// || std::is_same_v; + +template +concept is_floating_point = std::is_floating_point_v; + +// template +// struct is_arithmetic // NOLINT(readability-identifier-naming) +// { +// static constexpr bool value = std::is_arithmetic_v; +// }; + +// template <> struct is_arithmetic { static constexpr bool value = true; }; +// template <> struct is_arithmetic { static constexpr bool value = true; }; +// template <> struct 
is_arithmetic { static constexpr bool value = true; }; +// template <> struct is_arithmetic { static constexpr bool value = true; }; + +// template +// inline constexpr bool is_arithmetic_v = is_arithmetic::value; +// +template +struct make_unsigned // NOLINT(readability-identifier-naming) +{ + using type = std::make_unsigned_t; +}; + +// template <> struct make_unsigned { using type = UInt128; }; +// template <> struct make_unsigned { using type = UInt128; }; +// template <> struct make_unsigned { using type = UInt256; }; +// template <> struct make_unsigned { using type = UInt256; }; + +template +using make_unsigned_t = typename make_unsigned::type; + +// template +// struct make_signed // NOLINT(readability-identifier-naming) +// { +// using type = std::make_signed_t; +// }; +// +// template <> struct make_signed { using type = Int128; }; +// template <> struct make_signed { using type = Int128; }; +// template <> struct make_signed { using type = Int256; }; +// template <> struct make_signed { using type = Int256; }; +// +// template using make_signed_t = typename make_signed::type; +// +// template +// struct is_big_int // NOLINT(readability-identifier-naming) +// { +// static constexpr bool value = false; +// }; +// +// template <> struct is_big_int { static constexpr bool value = true; }; +// template <> struct is_big_int { static constexpr bool value = true; }; +// template <> struct is_big_int { static constexpr bool value = true; }; +// template <> struct is_big_int { static constexpr bool value = true; }; +// +// template +// inline constexpr bool is_big_int_v = is_big_int::value; diff --git a/be/src/vec/core/field.cpp b/be/src/vec/core/field.cpp index 9970b284ceb4b06..337c2c395f201d7 100644 --- a/be/src/vec/core/field.cpp +++ b/be/src/vec/core/field.cpp @@ -170,6 +170,7 @@ bool dec_less_or_equal(T x, T y, UInt32 x_scale, UInt32 y_scale) { DECLARE_DECIMAL_COMPARISON(Decimal32) DECLARE_DECIMAL_COMPARISON(Decimal64) DECLARE_DECIMAL_COMPARISON(Decimal128) 
+DECLARE_DECIMAL_COMPARISON(Decimal256) template <> bool decimal_equal(Decimal128I x, Decimal128I y, UInt32 xs, UInt32 ys) { diff --git a/be/src/vec/core/field.h b/be/src/vec/core/field.h index 9aadfe2a0ae2d89..1912057ab2b2128 100644 --- a/be/src/vec/core/field.h +++ b/be/src/vec/core/field.h @@ -65,7 +65,7 @@ struct NearestFieldTypeImpl { template using NearestFieldType = typename NearestFieldTypeImpl::Type; -template +template struct AvgNearestFieldTypeTrait { using Type = typename NearestFieldTypeImpl::Type; }; @@ -90,11 +90,31 @@ struct AvgNearestFieldTypeTrait { using Type = Decimal128; }; +template <> +struct AvgNearestFieldTypeTrait { + using Type = Decimal256; +}; + template <> struct AvgNearestFieldTypeTrait { using Type = double; }; +template <> +struct AvgNearestFieldTypeTrait { + using Type = Decimal256; +}; + +template <> +struct AvgNearestFieldTypeTrait { + using Type = Decimal256; +}; + +template <> +struct AvgNearestFieldTypeTrait { + using Type = Decimal256; +}; + class Field; using FieldVector = std::vector; @@ -319,6 +339,8 @@ class Field { Bitmap = 27, HyperLogLog = 28, QuantileState = 29, + Int256 = 30, + Decimal256 = 31, }; static const int MIN_NON_POD = 16; @@ -355,6 +377,8 @@ class Field { return "Decimal128"; case Decimal128I: return "Decimal128I"; + case Decimal256: + return "Decimal256"; case FixedLengthObject: return "FixedLengthObject"; case VariantMap: @@ -380,7 +404,7 @@ class Field { static bool is_decimal(Types::Which which) { return (which >= Types::Decimal32 && which <= Types::Decimal128) || - which == Types::Decimal128I; + which == Types::Decimal128I || which == Types::Decimal256; } Field() : which(Types::Null) {} @@ -551,6 +575,8 @@ class Field { return get() <=> rhs.get(); case Types::Decimal128I: return get() <=> rhs.get(); + case Types::Decimal256: + return get() <=> rhs.get(); default: LOG(FATAL) << "lhs type not equal with rhs, lhs=" << Types::to_string(which) << ", rhs=" << Types::to_string(rhs.which); @@ -562,7 +588,8 @@ 
class Field { std::aligned_union_t, DecimalField, DecimalField, - DecimalField, BitmapValue, HyperLogLog, QuantileState> + DecimalField, DecimalField, BitmapValue, + HyperLogLog, QuantileState> storage; Types::Which which; @@ -640,6 +667,9 @@ class Field { case Types::Decimal128I: f(field.template get>()); return; + case Types::Decimal256: + f(field.template get>()); + return; case Types::VariantMap: f(field.template get()); return; @@ -753,6 +783,10 @@ struct TypeId> { static constexpr const TypeIndex value = TypeIndex::Decimal128I; }; template <> +struct TypeId> { + static constexpr const TypeIndex value = TypeIndex::Decimal256; +}; +template <> struct Field::TypeToEnum { static constexpr Types::Which value = Types::Null; }; @@ -773,6 +807,10 @@ struct Field::TypeToEnum { static constexpr Types::Which value = Types::Int128; }; template <> +struct Field::TypeToEnum { + static constexpr Types::Which value = Types::Int256; +}; +template <> struct Field::TypeToEnum { static constexpr Types::Which value = Types::Float64; }; @@ -813,6 +851,10 @@ struct Field::TypeToEnum> { static constexpr Types::Which value = Types::Decimal128I; }; template <> +struct Field::TypeToEnum> { + static constexpr Types::Which value = Types::Decimal256; +}; +template <> struct Field::TypeToEnum { static constexpr Types::Which value = Types::VariantMap; }; @@ -893,6 +935,10 @@ struct Field::EnumToType { using Type = DecimalField; }; template <> +struct Field::EnumToType { + using Type = DecimalField; +}; +template <> struct Field::EnumToType { using Type = VariantMap; }; @@ -993,6 +1039,10 @@ struct NearestFieldTypeImpl { using Type = DecimalField; }; template <> +struct NearestFieldTypeImpl { + using Type = DecimalField; +}; +template <> struct NearestFieldTypeImpl> { using Type = DecimalField; }; @@ -1009,6 +1059,10 @@ struct NearestFieldTypeImpl> { using Type = DecimalField; }; template <> +struct NearestFieldTypeImpl> { + using Type = DecimalField; +}; +template <> struct 
NearestFieldTypeImpl { using Type = Float64; }; diff --git a/be/src/vec/core/types.h b/be/src/vec/core/types.h index abb5c9255c641f1..6c4ecdad7814812 100644 --- a/be/src/vec/core/types.h +++ b/be/src/vec/core/types.h @@ -28,6 +28,10 @@ #include "common/consts.h" #include "util/binary_cast.hpp" #include "vec/common/int_exp.h" +#include "vec/core/wide_integer.h" +#include "vec/core/wide_integer_to_string.h" + +using wide::Int256; namespace doris { @@ -92,7 +96,9 @@ enum class TypeIndex { VARIANT = 41, QuantileState = 42, Time = 43, - AggState + AggState = 44, + Decimal256 = 45, + Int256 }; struct Consted { @@ -277,10 +283,21 @@ struct TypeName { static const char* get() { return "Int128"; } }; template <> +inline constexpr bool IsNumber = true; +template <> +struct TypeName { + static const char* get() { return "Int256"; } +}; +template <> struct TypeId { static constexpr const TypeIndex value = TypeIndex::Int128; }; +template <> +struct TypeId { + static constexpr const TypeIndex value = TypeIndex::Int256; +}; + using Date = Int64; using DateTime = Int64; using DateV2 = UInt32; @@ -300,6 +317,10 @@ template <> inline constexpr Int128 decimal_scale_multiplier(UInt32 scale) { return common::exp10_i128(scale); } +template <> +inline constexpr Int256 decimal_scale_multiplier(UInt32 scale) { + return common::exp10_i256(scale); +} /// Own FieldType for Decimal. /// It is only a "storage" for decimal. To perform operations, you also have to provide a scale (number of digits after point). 
@@ -314,6 +335,7 @@ struct Decimal { #define DECLARE_NUMERIC_CTOR(TYPE) \ Decimal(const TYPE& value_) : value(value_) {} + DECLARE_NUMERIC_CTOR(Int256) DECLARE_NUMERIC_CTOR(Int128) DECLARE_NUMERIC_CTOR(Int32) DECLARE_NUMERIC_CTOR(Int64) @@ -348,6 +370,12 @@ struct Decimal { operator T() const { return value; } + operator wide::Int256() const { + wide::Int256 result; + wide::Int256::_impl::wide_integer_from_builtin(result, value); + return result; + } + const Decimal& operator++() { value++; return *this; @@ -384,8 +412,11 @@ struct Decimal { constexpr auto precision = std::is_same_v ? BeConsts::MAX_DECIMAL32_PRECISION - : (std::is_same_v ? BeConsts::MAX_DECIMAL64_PRECISION - : BeConsts::MAX_DECIMAL128_PRECISION); + : (std::is_same_v + ? BeConsts::MAX_DECIMAL64_PRECISION + : (std::is_same_v + ? BeConsts::MAX_DECIMAL128_PRECISION + : BeConsts::MAX_DECIMAL256_PRECISION)); return precision + 1 // Add a space for decimal place + 1 // Add a space for leading 0 + 1; // Add a space for negative sign @@ -393,18 +424,27 @@ struct Decimal { std::string to_string(UInt32 scale) const { if (value == std::numeric_limits::min()) { - fmt::memory_buffer buffer; - fmt::format_to(buffer, "{}", value); - std::string res {buffer.data(), buffer.size()}; - res.insert(res.size() - scale, "."); - return res; + if constexpr (std::is_same_v) { + std::string res {wide::to_string(value)}; + res.insert(res.size() - scale, "."); + return res; + } else { + fmt::memory_buffer buffer; + fmt::format_to(buffer, "{}", value); + std::string res {buffer.data(), buffer.size()}; + res.insert(res.size() - scale, "."); + return res; + } } static constexpr auto precision = std::is_same_v ? BeConsts::MAX_DECIMAL32_PRECISION - : (std::is_same_v ? BeConsts::MAX_DECIMAL64_PRECISION - : BeConsts::MAX_DECIMAL128_PRECISION); + : (std::is_same_v + ? BeConsts::MAX_DECIMAL64_PRECISION + : (std::is_same_v + ? 
BeConsts::MAX_DECIMAL128_PRECISION + : BeConsts::MAX_DECIMAL256_PRECISION)); bool is_nagetive = value < 0; int max_result_length = precision + (scale > 0) // Add a space for decimal place + (scale == precision) // Add a space for leading 0 @@ -425,14 +465,20 @@ struct Decimal { whole_part = abs_value / decimal_scale_multiplier(scale); frac_part = abs_value % decimal_scale_multiplier(scale); } - auto end = fmt::format_to(str.data() + pos, "{}", whole_part); - pos = end - str.data(); + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(whole_part)}; + auto end = fmt::format_to(str.data() + pos, "{}", num_str); + pos = end - str.data(); + } else { + auto end = fmt::format_to(str.data() + pos, "{}", whole_part); + pos = end - str.data(); + } if (scale) { str[pos++] = '.'; for (auto end_pos = pos + scale - 1; end_pos >= pos && frac_part > 0; --end_pos, frac_part /= 10) { - str[end_pos] += frac_part % 10; + str[end_pos] += (int)(frac_part % 10); } } @@ -450,8 +496,15 @@ struct Decimal { __attribute__((always_inline)) size_t to_string(char* dst, UInt32 scale, const T& scale_multiplier) const { if (UNLIKELY(value == std::numeric_limits::min())) { - auto end = fmt::format_to(dst, "{}", value); - return end - dst; + if constexpr (std::is_same_v) { + // handle scale? 
+ std::string num_str {wide::to_string(value)}; + auto end = fmt::format_to(dst, "{}", num_str); + return end - dst; + } else { + auto end = fmt::format_to(dst, "{}", value); + return end - dst; + } } bool is_negative = value < 0; @@ -469,8 +522,14 @@ struct Decimal { whole_part = abs_value / scale_multiplier; frac_part = abs_value % scale_multiplier; } - auto end = fmt::format_to(dst + pos, "{}", whole_part); - pos = end - dst; + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(whole_part)}; + auto end = fmt::format_to(dst + pos, "{}", num_str); + pos = end - dst; + } else { + auto end = fmt::format_to(dst + pos, "{}", whole_part); + pos = end - dst; + } if (LIKELY(scale)) { int low_scale = 0; @@ -490,8 +549,14 @@ struct Decimal { pos += scale - low_scale; } if (frac_part) { - end = fmt::format_to(&dst[pos], "{}", frac_part); - pos = end - dst; + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(frac_part)}; + auto end = fmt::format_to(&dst[pos], "{}", num_str); + pos = end - dst; + } else { + auto end = fmt::format_to(&dst[pos], "{}", frac_part); + pos = end - dst; + } } } @@ -507,6 +572,7 @@ struct Decimal128I : public Decimal { #define DECLARE_NUMERIC_CTOR(TYPE) \ Decimal128I(const TYPE& value_) : Decimal(value_) {} + DECLARE_NUMERIC_CTOR(Int256) DECLARE_NUMERIC_CTOR(Int128) DECLARE_NUMERIC_CTOR(Int32) DECLARE_NUMERIC_CTOR(Int64) @@ -522,9 +588,291 @@ struct Decimal128I : public Decimal { } }; +template <> +struct Decimal { + using T = Int256; + using NativeType = Int256; + + Decimal() = default; + Decimal(Decimal&&) = default; + Decimal(const Decimal&) = default; + +#define DECLARE_NUMERIC_CTOR(TYPE) \ + explicit Decimal(const TYPE& value_) : value(value_) {} + + DECLARE_NUMERIC_CTOR(Int256) + DECLARE_NUMERIC_CTOR(Int128) + DECLARE_NUMERIC_CTOR(Int32) + DECLARE_NUMERIC_CTOR(Int64) + DECLARE_NUMERIC_CTOR(UInt32) + DECLARE_NUMERIC_CTOR(UInt64) + +#undef DECLARE_NUMERIC_CTOR + + explicit Decimal(const Float32& value_) :
value(value_) { + if constexpr (std::is_integral::value) { + value = round(value_); + } + } + explicit Decimal(const Float64& value_) : value(value_) { + if constexpr (std::is_integral::value) { + value = round(value_); + } + } + + static Decimal double_to_decimal(double value_) { + DecimalV2Value decimal_value; + decimal_value.assign_from_double(value_); + return Decimal(binary_cast(decimal_value)); + } + + template + explicit Decimal(const Decimal& x) { + value = x.value; + } + + constexpr Decimal& operator=(Decimal&&) = default; + constexpr Decimal& operator=(const Decimal&) = default; + + operator T() const { return value; } + + operator Int128() const { return (Int128)value.items[0] + ((Int128)(value.items[1]) << 64); } + + const Decimal& operator++() { + value++; + return *this; + } + const Decimal& operator--() { + value--; + return *this; + } + + const Decimal& operator+=(const T& x) { + value += x; + return *this; + } + const Decimal& operator-=(const T& x) { + value -= x; + return *this; + } + const Decimal& operator*=(const T& x) { + value *= x; + return *this; + } + const Decimal& operator/=(const T& x) { + value /= x; + return *this; + } + const Decimal& operator%=(const T& x) { + value %= x; + return *this; + } + + auto operator<=>(const Decimal& x) const { return value <=> x.value; } + + static constexpr int max_string_length() { + constexpr auto precision = + std::is_same_v + ? BeConsts::MAX_DECIMAL32_PRECISION + : (std::is_same_v + ? BeConsts::MAX_DECIMAL64_PRECISION + : (std::is_same_v + ? 
BeConsts::MAX_DECIMAL128_PRECISION + : BeConsts::MAX_DECIMAL256_PRECISION)); + return precision + 1 // Add a space for decimal place + + 1 // Add a space for leading 0 + + 1; // Add a space for negative sign + } + + std::string to_string(UInt32 scale) const { + if (value == std::numeric_limits::min()) { + if constexpr (std::is_same_v) { + std::string res {wide::to_string(value)}; + res.insert(res.size() - scale, "."); + return res; + } else { + fmt::memory_buffer buffer; + fmt::format_to(buffer, "{}", value); + std::string res {buffer.data(), buffer.size()}; + res.insert(res.size() - scale, "."); + return res; + } + } + + static constexpr auto precision = + std::is_same_v + ? BeConsts::MAX_DECIMAL32_PRECISION + : (std::is_same_v + ? BeConsts::MAX_DECIMAL64_PRECISION + : (std::is_same_v + ? BeConsts::MAX_DECIMAL128_PRECISION + : BeConsts::MAX_DECIMAL256_PRECISION)); + bool is_nagetive = value < 0; + int max_result_length = precision + (scale > 0) // Add a space for decimal place + + (scale == precision) // Add a space for leading 0 + + (is_nagetive); // Add a space for negative sign + std::string str = std::string(max_result_length, '0'); + + T abs_value = value; + int pos = 0; + + if (is_nagetive) { + abs_value = -value; + str[pos++] = '-'; + } + + T whole_part = abs_value; + T frac_part; + if (scale) { + whole_part = abs_value / decimal_scale_multiplier(scale); + frac_part = abs_value % decimal_scale_multiplier(scale); + } + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(whole_part)}; + auto end = fmt::format_to(str.data() + pos, "{}", num_str); + pos = end - str.data(); + } else { + auto end = fmt::format_to(str.data() + pos, "{}", whole_part); + pos = end - str.data(); + } + + if (scale) { + str[pos++] = '.'; + for (auto end_pos = pos + scale - 1; end_pos >= pos && frac_part > 0; + --end_pos, frac_part /= 10) { + str[end_pos] += (int)(frac_part % 10); + } + } + + str.resize(pos + scale); + return str; + } + + /** + * Got the string 
representation of a decimal. + * @param dst Store the result, should be pre-allocated. + * @param scale Decimal's scale. + * @param scale_multiplier Decimal's scale multiplier. + * @return The length of string. + */ + __attribute__((always_inline)) size_t to_string(char* dst, UInt32 scale, + const T& scale_multiplier) const { + if (UNLIKELY(value == std::numeric_limits::min())) { + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(value)}; + auto end = fmt::format_to(dst, "{}", num_str); + return end - dst; + } else { + auto end = fmt::format_to(dst, "{}", value); + return end - dst; + } + } + + bool is_negative = value < 0; + T abs_value = value; + int pos = 0; + + if (is_negative) { + abs_value = -value; + dst[pos++] = '-'; + } + + T whole_part = abs_value; + T frac_part; + if (LIKELY(scale)) { + whole_part = abs_value / scale_multiplier; + frac_part = abs_value % scale_multiplier; + } + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(whole_part)}; + auto end = fmt::format_to(dst + pos, "{}", num_str); + pos = end - dst; + } else { + auto end = fmt::format_to(dst + pos, "{}", whole_part); + pos = end - dst; + } + + if (LIKELY(scale)) { + int low_scale = 0; + int high_scale = scale; + while (low_scale < high_scale) { + int mid_scale = (high_scale + low_scale) >> 1; + const auto mid_scale_factor = decimal_scale_multiplier(mid_scale); + if (mid_scale_factor <= frac_part) { + low_scale = mid_scale + 1; + } else { + high_scale = mid_scale; + } + } + dst[pos++] = '.'; + if (low_scale < scale) { + memset(&dst[pos], '0', scale - low_scale); + pos += scale - low_scale; + } + if (frac_part) { + if constexpr (std::is_same_v) { + std::string num_str {wide::to_string(frac_part)}; + auto end = fmt::format_to(dst + pos, "{}", num_str); + pos = end - dst; + } else { + auto end = fmt::format_to(&dst[pos], "{}", frac_part); + pos = end - dst; + } + } + } + + return pos; + } + + T value; +}; + using Decimal32 = Decimal; using Decimal64 = 
Decimal; using Decimal128 = Decimal; +using Decimal256 = Decimal; +template +inline Decimal operator-(const Decimal& x) { + return -x.value; +} + +inline Decimal256 operator+(const Decimal256& x, const Decimal256& y) { + return Decimal256(x.value + y.value); +} +inline Decimal256 operator-(const Decimal256& x, const Decimal256& y) { + return Decimal256(x.value - y.value); +} +inline Decimal256 operator*(const Decimal256& x, const Decimal256& y) { + return Decimal256(x.value * y.value); +} +inline Decimal256 operator/(const Decimal256& x, const Decimal256& y) { + return Decimal256(x.value / y.value); +} +inline Decimal256 operator%(const Decimal256& x, const Decimal256& y) { + return Decimal256(x.value % y.value); +} +inline Decimal256 operator-(const Decimal256& x) { + return Decimal256(-x.value); +} + +inline bool operator<(const Decimal256& x, const Decimal256& y) { + return x.value < y.value; +} +inline bool operator>(const Decimal256& x, const Decimal256& y) { + return x.value > y.value; +} +inline bool operator<=(const Decimal256& x, const Decimal256& y) { + return x.value <= y.value; +} +inline bool operator>=(const Decimal256& x, const Decimal256& y) { + return x.value >= y.value; +} +inline bool operator==(const Decimal256& x, const Decimal256& y) { + return x.value == y.value; +} +inline bool operator!=(const Decimal256& x, const Decimal256& y) { + return x.value != y.value; +} template <> struct TypeName { @@ -543,6 +891,11 @@ struct TypeName { static const char* get() { return "Decimal128I"; } }; +template <> +struct TypeName { + static const char* get() { return "Decimal256"; } +}; + template <> struct TypeId { static constexpr const TypeIndex value = TypeIndex::Decimal32; @@ -559,6 +912,10 @@ template <> struct TypeId { static constexpr const TypeIndex value = TypeIndex::Decimal128I; }; +template <> +struct TypeId { + static constexpr const TypeIndex value = TypeIndex::Decimal256; +}; template constexpr bool IsDecimalNumber = false; @@ -570,6 +927,8 @@ 
template <> inline constexpr bool IsDecimalNumber = true; template <> inline constexpr bool IsDecimalNumber = true; +template <> +inline constexpr bool IsDecimalNumber = true; template constexpr bool IsDecimal128 = false; @@ -581,6 +940,11 @@ constexpr bool IsDecimal128I = false; template <> inline constexpr bool IsDecimal128I = true; +template +constexpr bool IsDecimal256 = false; +template <> +inline constexpr bool IsDecimal256 = true; + template constexpr bool IsDecimalV2 = IsDecimal128 && !IsDecimal128I; @@ -588,6 +952,10 @@ template using DisposeDecimal = std::conditional_t, Decimal128, std::conditional_t, Decimal128I, U>>; +template +using DisposeDecimal256 = std::conditional_t, Decimal128, + std::conditional_t, Decimal256, U>>; + template constexpr bool IsFloatNumber = false; template <> @@ -615,6 +983,10 @@ template <> struct NativeType { using Type = Int128; }; +template <> +struct NativeType { + using Type = Int256; +}; inline const char* getTypeName(TypeIndex idx) { switch (idx) { @@ -640,6 +1012,8 @@ inline const char* getTypeName(TypeIndex idx) { return TypeName::get(); case TypeIndex::Int128: return TypeName::get(); + case TypeIndex::Int256: + return TypeName::get(); case TypeIndex::Float32: return TypeName::get(); case TypeIndex::Float64: @@ -670,6 +1044,8 @@ inline const char* getTypeName(TypeIndex idx) { return TypeName::get(); case TypeIndex::Decimal128I: return TypeName::get(); + case TypeIndex::Decimal256: + return TypeName::get(); case TypeIndex::UUID: return "UUID"; case TypeIndex::Array: @@ -740,6 +1116,15 @@ struct std::hash { } }; +template <> +struct std::hash { + size_t operator()(const doris::vectorized::Decimal256& x) const { + return std::hash()(x.value >> 192) ^ std::hash()(x.value >> 128) ^ + std::hash()(x.value >> 64) ^ + std::hash()(x.value & std::numeric_limits::max()); + } +}; + constexpr bool typeindex_is_int(doris::vectorized::TypeIndex index) { using TypeIndex = doris::vectorized::TypeIndex; switch (index) { diff --git 
a/be/src/vec/core/wide_integer.h b/be/src/vec/core/wide_integer.h new file mode 100644 index 000000000000000..e7902e414a854f8 --- /dev/null +++ b/be/src/vec/core/wide_integer.h @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +/////////////////////////////////////////////////////////////// +// Distributed under the Boost Software License, Version 1.0. +// (See at http://www.boost.org/LICENSE_1_0.txt) +/////////////////////////////////////////////////////////////// + +/* Divide and multiply + * + * + * Copyright (c) 2008 + * Evan Teran + * + * Permission to use, copy, modify, and distribute this software and its + * documentation for any purpose and without fee is hereby granted, provided + * that the above copyright notice appears in all copies and that both the + * copyright notice and this permission notice appear in supporting + * documentation, and that the same name not be used in advertising or + * publicity pertaining to distribution of the software without specific, + * written prior permission. We make no representations about the + * suitability this software for any purpose. It is provided "as is" + * without express or implied warranty. 
+ */ +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/base/base/wide_integer.h +// and modified by Doris +#pragma once + +#include +#include +#include +#include + +// NOLINTBEGIN(*) + +namespace wide { +template +class integer; +} + +namespace std { + +template +struct common_type, wide::integer>; + +template +struct common_type, Arithmetic>; + +template +struct common_type>; + +} // namespace std + +namespace wide { + +template +class integer { +public: + using base_type = uint64_t; + using signed_base_type = int64_t; + + // ctors + constexpr integer() noexcept = default; + + template + constexpr integer(T rhs) noexcept; + + template + constexpr integer(std::initializer_list il) noexcept; + + // assignment + template + constexpr integer& operator=(const integer& rhs) noexcept; + + template + constexpr integer& operator=(Arithmetic rhs) noexcept; + + template + constexpr integer& operator*=(const Arithmetic& rhs); + + template + constexpr integer& operator/=(const Arithmetic& rhs); + + template + constexpr integer& operator+=(const Arithmetic& rhs) noexcept( + std::is_same_v); + + template + constexpr integer& operator-=(const Arithmetic& rhs) noexcept( + std::is_same_v); + + template + constexpr integer& operator%=(const Integral& rhs); + + template + constexpr integer& operator&=(const Integral& rhs) noexcept; + + template + constexpr integer& operator|=(const Integral& rhs) noexcept; + + template + constexpr integer& operator^=(const Integral& rhs) noexcept; + + constexpr integer& operator<<=(int n) noexcept; + constexpr integer& operator>>=(int n) noexcept; + + constexpr integer& operator++() noexcept(std::is_same_v); + constexpr integer operator++(int) noexcept(std::is_same_v); + constexpr integer& operator--() noexcept(std::is_same_v); + constexpr integer operator--(int) noexcept(std::is_same_v); + + // observers + + constexpr explicit operator bool() const noexcept; + + template , T>> + constexpr operator T() const 
noexcept; + + constexpr operator long double() const noexcept; + constexpr operator double() const noexcept; + constexpr operator float() const noexcept; + + struct _impl; + + base_type items[_impl::item_count]; + +private: + template + friend class integer; + + friend class std::numeric_limits>; + friend class std::numeric_limits>; +}; + +using Int256 = integer<256, signed>; +using UInt256 = integer<256, unsigned>; + +template +static constexpr bool ArithmeticConcept() noexcept; + +template +using _only_arithmetic = + typename std::enable_if() && ArithmeticConcept()>::type; + +template +static constexpr bool IntegralConcept() noexcept; + +template +using _only_integer = typename std::enable_if() && IntegralConcept()>::type; + +// Unary operators +template +constexpr integer operator~(const integer& lhs) noexcept; + +template +constexpr integer operator-(const integer& lhs) noexcept( + std::is_same_v); + +template +constexpr integer operator+(const integer& lhs) noexcept( + std::is_same_v); + +// Binary operators +template +std::common_type_t, integer> constexpr operator*( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator*(const Arithmetic& rhs, + const Arithmetic2& lhs); + +template +std::common_type_t, integer> constexpr operator/( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator/(const Arithmetic& rhs, + const Arithmetic2& lhs); + +template +std::common_type_t, integer> constexpr operator+( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator+(const Arithmetic& rhs, + const Arithmetic2& lhs); + +template +std::common_type_t, integer> constexpr operator-( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator-(const Arithmetic& rhs, + const Arithmetic2& lhs); + +template +std::common_type_t, integer> constexpr operator%( + const integer& lhs, const integer& rhs); +template > +std::common_type_t 
constexpr operator%(const Integral& rhs, + const Integral2& lhs); + +template +std::common_type_t, integer> constexpr operator&( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator&(const Integral& rhs, + const Integral2& lhs); + +template +std::common_type_t, integer> constexpr operator|( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator|(const Integral& rhs, + const Integral2& lhs); + +template +std::common_type_t, integer> constexpr operator^( + const integer& lhs, const integer& rhs); +template > +std::common_type_t constexpr operator^(const Integral& rhs, + const Integral2& lhs); + +// TODO: Integral +template +constexpr integer operator<<(const integer& lhs, int n) noexcept; + +template +constexpr integer operator>>(const integer& lhs, int n) noexcept; + +template >> +constexpr integer operator<<(const integer& lhs, Int n) noexcept { + return lhs << int(n); +} +template >> +constexpr integer operator>>(const integer& lhs, Int n) noexcept { + return lhs >> int(n); +} + +template +constexpr bool operator<(const integer& lhs, const integer& rhs); +template > +constexpr bool operator<(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr bool operator>(const integer& lhs, const integer& rhs); +template > +constexpr bool operator>(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr bool operator<=(const integer& lhs, const integer& rhs); +template > +constexpr bool operator<=(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr bool operator>=(const integer& lhs, const integer& rhs); +template > +constexpr bool operator>=(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr bool operator==(const integer& lhs, const integer& rhs); +template > +constexpr bool operator==(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr bool operator!=(const integer& lhs, const integer& rhs); +template > 
+constexpr bool operator!=(const Arithmetic& rhs, const Arithmetic2& lhs); + +template +constexpr auto operator<=>(const integer& lhs, const integer& rhs); +template > +constexpr auto operator<=>(const Arithmetic& rhs, const Arithmetic2& lhs); + +} // namespace wide + +// NOLINTEND(*) + +#include "wide_integer_impl.h" diff --git a/be/src/vec/core/wide_integer_impl.h b/be/src/vec/core/wide_integer_impl.h new file mode 100644 index 000000000000000..29ad81873fc84f3 --- /dev/null +++ b/be/src/vec/core/wide_integer_impl.h @@ -0,0 +1,1389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +/// Original is here https://github.com/cerevra/int +/// Distributed under the Boost Software License, Version 1.0. 
+/// (See at http://www.boost.org/LICENSE_1_0.txt) + +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/base/base/wide_integer_impl.h +// and modified by Doris +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "common/exception.h" + +// NOLINTBEGIN(*) + +/// Use same extended double for all platforms +#if (LDBL_MANT_DIG == 64) +#define CONSTEXPR_FROM_DOUBLE constexpr +using FromDoubleIntermediateType = long double; +#else +#include +/// `wide_integer_from_builtin` can't be constexpr with non-literal `cpp_bin_float_double_extended` +#define CONSTEXPR_FROM_DOUBLE +using FromDoubleIntermediateType = boost::multiprecision::cpp_bin_float_double_extended; +#endif + +namespace CityHash_v1_0_2 { +struct uint128; +} + +namespace wide { + +template +struct IsWideInteger { + static const constexpr bool value = false; +}; + +template +struct IsWideInteger> { + static const constexpr bool value = true; +}; + +template +static constexpr bool ArithmeticConcept() noexcept { + return std::is_arithmetic_v || IsWideInteger::value; +} + +template +static constexpr bool IntegralConcept() noexcept { + return std::is_integral_v || IsWideInteger::value; +} + +template +class IsTupleLike { + template + static auto check(U* p) -> decltype(std::tuple_size::value, int()); + template + static void check(...); + +public: + static constexpr const bool value = !std::is_void(nullptr))>::value; +}; + +} // namespace wide + +namespace std { + +// numeric limits +template +class numeric_limits> { +public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = is_same::value; + static constexpr bool is_integer = true; + static constexpr bool is_exact = true; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = true; + static constexpr std::float_denorm_style has_denorm = std::denorm_absent; + static constexpr 
bool has_denorm_loss = false; + static constexpr std::float_round_style round_style = std::round_toward_zero; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = true; + static constexpr int digits = Bits - (is_same::value ? 1 : 0); + static constexpr int digits10 = digits * 0.30103 /*std::log10(2)*/; + static constexpr int max_digits10 = 0; + static constexpr int radix = 2; + static constexpr int min_exponent = 0; + static constexpr int min_exponent10 = 0; + static constexpr int max_exponent = 0; + static constexpr int max_exponent10 = 0; + static constexpr bool traps = true; + static constexpr bool tinyness_before = false; + + static constexpr wide::integer min() noexcept { + if (is_same::value) { + using T = wide::integer; + T res {}; + res.items[T::_impl::big(0)] = std::numeric_limits< + typename wide::integer::signed_base_type>::min(); + return res; + } + return wide::integer(0); + } + + static constexpr wide::integer max() noexcept { + using T = wide::integer; + T res {}; + res.items[T::_impl::big(0)] = + is_same::value + ? 
std::numeric_limits< + typename wide::integer::signed_base_type>::max() + : std::numeric_limits< + typename wide::integer::base_type>::max(); + for (unsigned i = 1; i < wide::integer::_impl::item_count; ++i) { + res.items[T::_impl::big(i)] = + std::numeric_limits::base_type>::max(); + } + return res; + } + + static constexpr wide::integer lowest() noexcept { return min(); } + static constexpr wide::integer epsilon() noexcept { return 0; } + static constexpr wide::integer round_error() noexcept { return 0; } + static constexpr wide::integer infinity() noexcept { return 0; } + static constexpr wide::integer quiet_NaN() noexcept { return 0; } + static constexpr wide::integer signaling_NaN() noexcept { return 0; } + static constexpr wide::integer denorm_min() noexcept { return 0; } +}; + +// type traits +template +struct common_type, wide::integer> { + using type = std::conditional_t < Bits == Bits2, + wide::integer && + std::is_same_v), + signed, unsigned>>, + std::conditional_t< + Bits2, wide::integer>>; +}; + +template +struct common_type, Arithmetic> { + static_assert(wide::ArithmeticConcept()); + + using type = std::conditional_t < std::is_floating_point_v, Arithmetic, + std::conditional_t, + std::conditional_t || + std::is_signed_v), + Arithmetic, wide::integer>>>>; +}; + +template +struct common_type> + : common_type, Arithmetic> {}; + +} // namespace std + +namespace wide { + +template +struct integer::_impl { + static constexpr size_t _bits = Bits; + static constexpr const unsigned byte_count = Bits / 8; + static constexpr const unsigned item_count = byte_count / sizeof(base_type); + static constexpr const unsigned base_bits = sizeof(base_type) * 8; + + static_assert(Bits % base_bits == 0); + + /// Simple iteration in both directions + static constexpr unsigned little(unsigned idx) { + if constexpr (std::endian::native == std::endian::little) + return idx; + else + return item_count - 1 - idx; + } + static constexpr unsigned big(unsigned idx) { + if constexpr 
(std::endian::native == std::endian::little) + return item_count - 1 - idx; + else + return idx; + } + static constexpr unsigned any(unsigned idx) { return idx; } + + template + constexpr static bool is_negative(const T& n) noexcept { + if constexpr (std::is_signed_v) + return n < 0; + else + return false; + } + + template + constexpr static bool is_negative(const integer& n) noexcept { + if constexpr (std::is_same_v) + return static_cast(n.items[integer::_impl::big(0)]) < 0; + else + return false; + } + + template + constexpr static auto make_positive(const T& n) noexcept { + if constexpr (std::is_signed_v) + return n < 0 ? -n : n; + else + return n; + } + + template + constexpr static integer make_positive(const integer& n) noexcept { + return is_negative(n) ? integer(operator_unary_minus(n)) : n; + } + + template + __attribute__((no_sanitize("undefined"))) constexpr static auto to_Integral(T f) noexcept { + /// NOTE: this can be called with DB::Decimal, and in this case, result + /// will be wrong + if constexpr (std::is_signed_v) + return static_cast(f); + else + return static_cast(f); + } + + template + constexpr static void wide_integer_from_builtin(integer& self, + Integral rhs) noexcept { + static_assert(sizeof(Integral) <= sizeof(base_type)); + + self.items[little(0)] = _impl::to_Integral(rhs); + + if constexpr (std::is_signed_v) { + if (rhs < 0) { + for (unsigned i = 1; i < item_count; ++i) self.items[little(i)] = -1; + return; + } + } + + for (unsigned i = 1; i < item_count; ++i) self.items[little(i)] = 0; + } + + template <> + constexpr static void wide_integer_from_builtin<__int128>(integer& self, + __int128 rhs) noexcept { + self.items[little(0)] = rhs; + self.items[little(1)] = rhs >> 64; + if (rhs < 0) { + for (unsigned i = 2; i < item_count; ++i) self.items[little(i)] = -1; + return; + } else { + for (unsigned i = 2; i < item_count; ++i) self.items[little(i)] = 0; + } + } + + template + constexpr static void wide_integer_from_tuple_like(integer& 
self, + const TupleLike& tuple) noexcept { + if constexpr (i < item_count) { + if constexpr (i < std::tuple_size_v) + self.items[i] = std::get(tuple); + else + self.items[i] = 0; + wide_integer_from_tuple_like(self, tuple); + } + } + + template + constexpr static void wide_integer_from_cityhash_uint128( + integer& self, const CityHashUInt128& value) noexcept { + static_assert(sizeof(item_count) >= 2); + + if constexpr (std::endian::native == std::endian::little) + wide_integer_from_tuple_like(self, std::make_pair(value.low64, value.high64)); + else + wide_integer_from_tuple_like(self, std::make_pair(value.high64, value.low64)); + } + + /** + * N.B. t is constructed from double, so max(t) = max(double) ~ 2^310 + * the recursive call happens when t / 2^64 > 2^64, so there won't be more than 5 of them. + * + * t = a1 * max_int + b1, a1 > max_int, b1 < max_int + * a1 = a2 * max_int + b2, a2 > max_int, b2 < max_int + * a_(n - 1) = a_n * max_int + b2, a_n <= max_int <- base case. + */ + template + constexpr static void set_multiplier(integer& self, T t) noexcept { + constexpr uint64_t max_int = std::numeric_limits::max(); + static_assert(std::is_same_v || std::is_same_v); + /// Implementation specific behaviour on overflow (if we don't check here, stack overflow will triggered in bigint_cast). + if constexpr (std::is_same_v) { + if (!std::isfinite(t)) { + self = 0; + return; + } + } else { + if (!boost::math::isfinite(t)) { + self = 0; + return; + } + } + + const T alpha = t / static_cast(max_int); + + /** Here we have to use strict comparison. + * The max_int is 2^64 - 1. + * When casted to floating point type, it will be rounded to the closest representable number, + * which is 2^64. + * But 2^64 is not representable in uint64_t, + * so the maximum representable number will be strictly less. + */ + if (alpha < static_cast(max_int)) + self = static_cast(alpha); + else // max(double) / 2^64 will surely contain less than 52 precision bits, so speed up computations. 
+ set_multiplier(self, static_cast(alpha)); + + self *= max_int; + self += static_cast(t - floor(alpha) * static_cast(max_int)); // += b_i + } + + CONSTEXPR_FROM_DOUBLE static void wide_integer_from_builtin(integer& self, + double rhs) noexcept { + constexpr int64_t max_int = std::numeric_limits::max(); + constexpr int64_t min_int = std::numeric_limits::lowest(); + + /// There are values in int64 that have more than 53 significant bits (in terms of double + /// representation). Such values, being promoted to double, are rounded up or down. If they are rounded up, + /// the result may not fit in 64 bits. + /// The example of such a number is 9.22337e+18. + /// As to_Integral does a static_cast to int64_t, it may result in UB. + /// The necessary check here is that FromDoubleIntermediateType has enough significant (mantissa) bits to store the + /// int64_t max value precisely. + + if (rhs > static_cast(min_int) && + rhs < static_cast(max_int)) { + self = static_cast(rhs); + return; + } + + const FromDoubleIntermediateType rhs_long_double = + (static_cast(rhs) < 0) + ? -static_cast(rhs) + : rhs; + + set_multiplier(self, rhs_long_double); + + if (rhs < 0) self = -self; + } + + template + constexpr static void wide_integer_from_wide_integer( + integer& self, const integer& rhs) noexcept { + constexpr const unsigned min_bits = (Bits < Bits2) ? 
Bits : Bits2; + constexpr const unsigned to_copy = min_bits / base_bits; + + for (unsigned i = 0; i < to_copy; ++i) + self.items[little(i)] = rhs.items[integer::_impl::little(i)]; + + if constexpr (Bits > Bits2) { + if constexpr (std::is_signed_v) { + if (rhs < 0) { + for (unsigned i = to_copy; i < item_count; ++i) self.items[little(i)] = -1; + return; + } + } + + for (unsigned i = to_copy; i < item_count; ++i) self.items[little(i)] = 0; + } + } + + template + constexpr static bool should_keep_size() { + return sizeof(T) <= byte_count; + } + + constexpr static integer shift_left(const integer& rhs, + unsigned n) noexcept { + integer lhs; + unsigned items_shift = n / base_bits; + + if (unsigned bit_shift = n % base_bits) { + unsigned overflow_shift = base_bits - bit_shift; + + lhs.items[big(0)] = rhs.items[big(items_shift)] << bit_shift; + for (unsigned i = 1; i < item_count - items_shift; ++i) { + lhs.items[big(i - 1)] |= rhs.items[big(items_shift + i)] >> overflow_shift; + lhs.items[big(i)] = rhs.items[big(items_shift + i)] << bit_shift; + } + } else { + for (unsigned i = 0; i < item_count - items_shift; ++i) + lhs.items[big(i)] = rhs.items[big(items_shift + i)]; + } + + for (unsigned i = 0; i < items_shift; ++i) lhs.items[little(i)] = 0; + return lhs; + } + + constexpr static integer shift_right(const integer& rhs, + unsigned n) noexcept { + integer lhs; + unsigned items_shift = n / base_bits; + unsigned bit_shift = n % base_bits; + + if (bit_shift) { + unsigned overflow_shift = base_bits - bit_shift; + + lhs.items[little(0)] = rhs.items[little(items_shift)] >> bit_shift; + for (unsigned i = 1; i < item_count - items_shift; ++i) { + lhs.items[little(i - 1)] |= rhs.items[little(items_shift + i)] << overflow_shift; + lhs.items[little(i)] = rhs.items[little(items_shift + i)] >> bit_shift; + } + } else { + for (unsigned i = 0; i < item_count - items_shift; ++i) + lhs.items[little(i)] = rhs.items[little(items_shift + i)]; + } + + if (is_negative(rhs)) { + if 
(bit_shift) + lhs.items[big(items_shift)] |= std::numeric_limits::max() + << (base_bits - bit_shift); + + for (unsigned i = 0; i < items_shift; ++i) + lhs.items[big(i)] = std::numeric_limits::max(); + } else { + for (unsigned i = 0; i < items_shift; ++i) lhs.items[big(i)] = 0; + } + + return lhs; + } + +private: + template + constexpr static base_type get_item(const T& x, unsigned idx) { + if constexpr (IsWideInteger::value) { + if (idx < T::_impl::item_count) return x.items[idx]; + return 0; + } else { + if constexpr (sizeof(T) <= sizeof(base_type)) { + if (little(0) == idx) return static_cast(x); + } else if (idx * sizeof(base_type) < sizeof(T)) + return x >> (idx * base_bits); // & std::numeric_limits::max() + return 0; + } + } + + template + constexpr static integer minus(const integer& lhs, T rhs) { + constexpr const unsigned rhs_items = + (sizeof(T) > sizeof(base_type)) ? (sizeof(T) / sizeof(base_type)) : 1; + constexpr const unsigned op_items = (item_count < rhs_items) ? item_count : rhs_items; + + integer res(lhs); + bool underflows[item_count] = {}; + + for (unsigned i = 0; i < op_items; ++i) { + base_type rhs_item = get_item(rhs, little(i)); + base_type& res_item = res.items[little(i)]; + + underflows[i] = res_item < rhs_item; + res_item -= rhs_item; + } + + for (unsigned i = 1; i < item_count; ++i) { + if (underflows[i - 1]) { + base_type& res_item = res.items[little(i)]; + if (res_item == 0) underflows[i] = true; + --res_item; + } + } + + return res; + } + + template + constexpr static integer plus(const integer& lhs, T rhs) { + constexpr const unsigned rhs_items = + (sizeof(T) > sizeof(base_type)) ? (sizeof(T) / sizeof(base_type)) : 1; + constexpr const unsigned op_items = (item_count < rhs_items) ? 
item_count : rhs_items; + + integer res(lhs); + bool overflows[item_count] = {}; + + for (unsigned i = 0; i < op_items; ++i) { + base_type rhs_item = get_item(rhs, little(i)); + base_type& res_item = res.items[little(i)]; + + res_item += rhs_item; + overflows[i] = res_item < rhs_item; + } + + for (unsigned i = 1; i < item_count; ++i) { + if (overflows[i - 1]) { + base_type& res_item = res.items[little(i)]; + ++res_item; + if (res_item == 0) overflows[i] = true; + } + } + + return res; + } + + template + constexpr static integer multiply(const integer& lhs, + const T& rhs) { + if constexpr (Bits == 256 && sizeof(base_type) == 8) { + /// @sa https://github.com/abseil/abseil-cpp/blob/master/absl/numeric/int128.h + using HalfType = unsigned __int128; + + HalfType a01 = (HalfType(lhs.items[little(1)]) << 64) + lhs.items[little(0)]; + HalfType a23 = (HalfType(lhs.items[little(3)]) << 64) + lhs.items[little(2)]; + HalfType a0 = lhs.items[little(0)]; + HalfType a1 = lhs.items[little(1)]; + + HalfType b01 = rhs; + uint64_t b0 = b01; + uint64_t b1 = 0; + HalfType b23 = 0; + if constexpr (sizeof(T) > 8) b1 = b01 >> 64; + if constexpr (sizeof(T) > 16) + b23 = (HalfType(rhs.items[little(3)]) << 64) + rhs.items[little(2)]; + + HalfType r23 = a23 * b01 + a01 * b23 + a1 * b1; + HalfType r01 = a0 * b0; + HalfType r12 = (r01 >> 64) + (r23 << 64); + HalfType r12_x = a1 * b0; + + integer res; + res.items[little(0)] = r01; + res.items[little(3)] = r23 >> 64; + + if constexpr (sizeof(T) > 8) { + HalfType r12_y = a0 * b1; + r12_x += r12_y; + if (r12_x < r12_y) ++res.items[little(3)]; + } + + r12 += r12_x; + if (r12 < r12_x) ++res.items[little(3)]; + + res.items[little(1)] = r12; + res.items[little(2)] = r12 >> 64; + return res; + } else if constexpr (Bits == 128 && sizeof(base_type) == 8) { + using CompilerUInt128 = unsigned __int128; + CompilerUInt128 a = + (CompilerUInt128(lhs.items[little(1)]) << 64) + + lhs.items[little( + 0)]; // 
NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) + CompilerUInt128 b = + (CompilerUInt128(rhs.items[little(1)]) << 64) + + rhs.items[little( + 0)]; // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) + CompilerUInt128 c = a * b; + integer res; + res.items[little(0)] = c; + res.items[little(1)] = c >> 64; + return res; + } else { + integer res {}; +#if 1 + integer lhs2 = plus(lhs, shift_left(lhs, 1)); + integer lhs3 = plus(lhs2, shift_left(lhs, 2)); +#endif + for (unsigned i = 0; i < item_count; ++i) { + base_type rhs_item = get_item(rhs, little(i)); + unsigned pos = i * base_bits; + + while (rhs_item) { +#if 1 /// optimization + if ((rhs_item & 0x7) == 0x7) { + res = plus(res, shift_left(lhs3, pos)); + rhs_item >>= 3; + pos += 3; + continue; + } + + if ((rhs_item & 0x3) == 0x3) { + res = plus(res, shift_left(lhs2, pos)); + rhs_item >>= 2; + pos += 2; + continue; + } +#endif + if (rhs_item & 1) res = plus(res, shift_left(lhs, pos)); + + rhs_item >>= 1; + ++pos; + } + } + + return res; + } + } + +public: + constexpr static integer operator_unary_tilda( + const integer& lhs) noexcept { + integer res; + + for (unsigned i = 0; i < item_count; ++i) res.items[any(i)] = ~lhs.items[any(i)]; + return res; + } + + constexpr static integer operator_unary_minus( + const integer& lhs) noexcept(std::is_same_v) { + return plus(operator_unary_tilda(lhs), 1); + } + + template + constexpr static auto operator_plus(const integer& lhs, + const T& rhs) noexcept(std::is_same_v) { + if constexpr (should_keep_size()) { + if (is_negative(rhs)) + return minus(lhs, -rhs); + else + return plus(lhs, rhs); + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, integer>:: + _impl::operator_plus(integer(lhs), rhs); + } + } + + template + constexpr static auto operator_minus(const integer& lhs, + const T& rhs) noexcept(std::is_same_v) { + if constexpr (should_keep_size()) { + if (is_negative(rhs)) + return plus(lhs, -rhs); + else + return minus(lhs, rhs); + 
} else { + static_assert(IsWideInteger::value); + return std::common_type_t, integer>:: + _impl::operator_minus(integer(lhs), rhs); + } + } + + template + constexpr static auto operator_star(const integer& lhs, const T& rhs) { + if constexpr (should_keep_size()) { + integer res; + + if constexpr (std::is_signed_v) { + res = multiply((is_negative(lhs) ? make_positive(lhs) : lhs), + (is_negative(rhs) ? make_positive(rhs) : rhs)); + } else { + res = multiply(lhs, (is_negative(rhs) ? make_positive(rhs) : rhs)); + } + + if (std::is_same_v && is_negative(lhs) != is_negative(rhs)) + res = operator_unary_minus(res); + + return res; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_star(T(lhs), rhs); + } + } + + template + constexpr static bool operator_greater(const integer& lhs, + const T& rhs) noexcept { + if constexpr (should_keep_size()) { + if (std::numeric_limits::is_signed && (is_negative(lhs) != is_negative(rhs))) + return is_negative(rhs); + + integer t = rhs; + for (unsigned i = 0; i < item_count; ++i) { + base_type rhs_item = get_item(t, big(i)); + + if (lhs.items[big(i)] != rhs_item) return lhs.items[big(i)] > rhs_item; + } + + return false; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_greater(T(lhs), + rhs); + } + } + + template + constexpr static bool operator_less(const integer& lhs, const T& rhs) noexcept { + if constexpr (should_keep_size()) { + if (std::numeric_limits::is_signed && (is_negative(lhs) != is_negative(rhs))) + return is_negative(lhs); + + integer t = rhs; + for (unsigned i = 0; i < item_count; ++i) { + base_type rhs_item = get_item(t, big(i)); + + if (lhs.items[big(i)] != rhs_item) return lhs.items[big(i)] < rhs_item; + } + + return false; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_less(T(lhs), rhs); + } + } + + template + constexpr static bool operator_eq(const integer& lhs, const T& 
rhs) noexcept { + if constexpr (should_keep_size()) { + integer t = rhs; + for (unsigned i = 0; i < item_count; ++i) { + base_type rhs_item = get_item(t, any(i)); + + if (lhs.items[any(i)] != rhs_item) return false; + } + + return true; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_eq(T(lhs), rhs); + } + } + + template + constexpr static auto operator_pipe(const integer& lhs, const T& rhs) noexcept { + if constexpr (should_keep_size()) { + integer res; + + for (unsigned i = 0; i < item_count; ++i) + res.items[little(i)] = lhs.items[little(i)] | get_item(rhs, little(i)); + return res; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_pipe(T(lhs), rhs); + } + } + + template + constexpr static auto operator_amp(const integer& lhs, const T& rhs) noexcept { + if constexpr (should_keep_size()) { + integer res; + + for (unsigned i = 0; i < item_count; ++i) + res.items[little(i)] = lhs.items[little(i)] & get_item(rhs, little(i)); + return res; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, T>::_impl::operator_amp(T(lhs), rhs); + } + } + + template + constexpr static bool is_zero(const T& x) { + bool is_zero = true; + for (auto item : x.items) { + if (item != 0) { + is_zero = false; + break; + } + } + return is_zero; + } + + /// returns quotient as result and remainder in numerator. 
+ template + constexpr static integer divide(integer& numerator, + integer denominator) { + static_assert(std::is_unsigned_v); + + if constexpr (Bits == 128 && sizeof(base_type) == 8) { + using CompilerUInt128 = unsigned __int128; + + CompilerUInt128 a = + (CompilerUInt128(numerator.items[little(1)]) << 64) + + numerator.items[little( + 0)]; // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) + CompilerUInt128 b = + (CompilerUInt128(denominator.items[little(1)]) << 64) + + denominator.items[little( + 0)]; // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) + CompilerUInt128 c = a / b; // NOLINT + + integer res; + res.items[little(0)] = c; + res.items[little(1)] = c >> 64; + + CompilerUInt128 remainder = a - b * c; + numerator.items[little(0)] = remainder; + numerator.items[little(1)] = remainder >> 64; + + return res; + } + + if (is_zero(denominator)) + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Division by zero"); + + integer x = 1; + integer quotient = 0; + + while (!operator_greater(denominator, numerator) && + is_zero(operator_amp(shift_right(denominator, Bits2 - 1), 1))) { + x = shift_left(x, 1); + denominator = shift_left(denominator, 1); + } + + while (!is_zero(x)) { + if (!operator_greater(denominator, numerator)) { + numerator = operator_minus(numerator, denominator); + quotient = operator_pipe(quotient, x); + } + + x = shift_right(x, 1); + denominator = shift_right(denominator, 1); + } + + return quotient; + } + + template + constexpr static auto operator_slash(const integer& lhs, const T& rhs) { + if constexpr (should_keep_size()) { + integer numerator = make_positive(lhs); + integer denominator = make_positive(integer(rhs)); + integer quotient = + integer::_impl::divide(numerator, std::move(denominator)); + + if (std::is_same_v && is_negative(rhs) != is_negative(lhs)) + quotient = operator_unary_minus(quotient); + return quotient; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, + 
integer>::operator_slash(T(lhs), + rhs); + } + } + + template + constexpr static auto operator_percent(const integer& lhs, const T& rhs) { + if constexpr (should_keep_size()) { + integer remainder = make_positive(lhs); + integer denominator = make_positive(integer(rhs)); + integer::_impl::divide(remainder, std::move(denominator)); + + if (std::is_same_v && is_negative(lhs)) + remainder = operator_unary_minus(remainder); + return remainder; + } else { + static_assert(IsWideInteger::value); + return std::common_type_t, + integer>::operator_percent(T(lhs), + rhs); + } + } + + // ^ + template + constexpr static auto operator_circumflex(const integer& lhs, + const T& rhs) noexcept { + if constexpr (should_keep_size()) { + integer t(rhs); + integer res = lhs; + + for (unsigned i = 0; i < item_count; ++i) res.items[any(i)] ^= t.items[any(i)]; + return res; + } else { + static_assert(IsWideInteger::value); + return T::operator_circumflex(T(lhs), rhs); + } + } + + constexpr static integer from_str(const char* c) { + integer res = 0; + + bool is_neg = std::is_same_v && *c == '-'; + if (is_neg) ++c; + + if (*c == '0' && (*(c + 1) == 'x' || *(c + 1) == 'X')) { // hex + ++c; + ++c; + while (*c) { + if (*c >= '0' && *c <= '9') { + res = multiply(res, 16U); + res = plus(res, *c - '0'); + ++c; + } else if (*c >= 'a' && *c <= 'f') { + res = multiply(res, 16U); + res = plus(res, *c - 'a' + 10U); + ++c; + } else if (*c >= 'A' && + *c <= 'F') { // tolower must be used, but it is not constexpr + res = multiply(res, 16U); + res = plus(res, *c - 'A' + 10U); + ++c; + } else + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Invalid char from"); + } + } else { // dec + while (*c) { + if (*c < '0' || *c > '9') + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Invalid char from"); + + res = multiply(res, 10U); + res = plus(res, *c - '0'); + ++c; + } + } + + if (is_neg) res = operator_unary_minus(res); + + return res; + } +}; + +// Members + +template +template 
+constexpr integer::integer(T rhs) noexcept : items {} { + if constexpr (IsWideInteger::value) + _impl::wide_integer_from_wide_integer(*this, rhs); + else if constexpr (IsTupleLike::value) + _impl::wide_integer_from_tuple_like(*this, rhs); + else if constexpr (std::is_same_v, CityHash_v1_0_2::uint128>) + _impl::wide_integer_from_cityhash_uint128(*this, rhs); + else + _impl::wide_integer_from_builtin(*this, rhs); +} + +template +template +constexpr integer::integer(std::initializer_list il) noexcept : items {} { + if (il.size() == 1) { + if constexpr (IsWideInteger::value) + _impl::wide_integer_from_wide_integer(*this, *il.begin()); + else if constexpr (IsTupleLike::value) + _impl::wide_integer_from_tuple_like(*this, *il.begin()); + else if constexpr (std::is_same_v, CityHash_v1_0_2::uint128>) + _impl::wide_integer_from_cityhash_uint128(*this, *il.begin()); + else + _impl::wide_integer_from_builtin(*this, *il.begin()); + } else if (il.size() == 0) { + _impl::wide_integer_from_builtin(*this, 0); + } else { + auto it = il.begin(); + for (unsigned i = 0; i < _impl::item_count; ++i) { + if (it < il.end()) { + items[_impl::little(i)] = *it; + ++it; + } else + items[_impl::little(i)] = 0; + } + } +} + +template +template +constexpr integer& integer::operator=( + const integer& rhs) noexcept { + _impl::wide_integer_from_wide_integer(*this, rhs); + return *this; +} + +template +template +constexpr integer& integer::operator=(T rhs) noexcept { + if constexpr (IsTupleLike::value) + _impl::wide_integer_from_tuple_like(*this, rhs); + else if constexpr (std::is_same_v, CityHash_v1_0_2::uint128>) + _impl::wide_integer_from_cityhash_uint128(*this, rhs); + else + _impl::wide_integer_from_builtin(*this, rhs); + return *this; +} + +template +template +constexpr integer& integer::operator*=(const T& rhs) { + *this = *this * rhs; + return *this; +} + +template +template +constexpr integer& integer::operator/=(const T& rhs) { + *this = *this / rhs; + return *this; +} + +template 
+template +constexpr integer& integer::operator+=(const T& rhs) noexcept( + std::is_same_v) { + *this = *this + rhs; + return *this; +} + +template +template +constexpr integer& integer::operator-=(const T& rhs) noexcept( + std::is_same_v) { + *this = *this - rhs; + return *this; +} + +template +template +constexpr integer& integer::operator%=(const T& rhs) { + *this = *this % rhs; + return *this; +} + +template +template +constexpr integer& integer::operator&=(const T& rhs) noexcept { + *this = *this & rhs; + return *this; +} + +template +template +constexpr integer& integer::operator|=(const T& rhs) noexcept { + *this = *this | rhs; + return *this; +} + +template +template +constexpr integer& integer::operator^=(const T& rhs) noexcept { + *this = *this ^ rhs; + return *this; +} + +template +constexpr integer& integer::operator<<=(int n) noexcept { + if (static_cast(n) >= Bits) + *this = 0; + else if (n > 0) + *this = _impl::shift_left(*this, n); + return *this; +} + +template +constexpr integer& integer::operator>>=(int n) noexcept { + if (static_cast(n) >= Bits) { + if (_impl::is_negative(*this)) + *this = -1; + else + *this = 0; + } else if (n > 0) + *this = _impl::shift_right(*this, n); + return *this; +} + +template +constexpr integer& integer::operator++() noexcept( + std::is_same_v) { + *this = _impl::operator_plus(*this, 1); + return *this; +} + +template +constexpr integer integer::operator++(int) noexcept( + std::is_same_v) { + auto tmp = *this; + *this = _impl::operator_plus(*this, 1); + return tmp; +} + +template +constexpr integer& integer::operator--() noexcept( + std::is_same_v) { + *this = _impl::operator_minus(*this, 1); + return *this; +} + +template +constexpr integer integer::operator--(int) noexcept( + std::is_same_v) { + auto tmp = *this; + *this = _impl::operator_minus(*this, 1); + return tmp; +} + +template +constexpr integer::operator bool() const noexcept { + return !_impl::operator_eq(*this, 0); +} + +template +template +constexpr 
integer::operator T() const noexcept { + static_assert(std::numeric_limits::is_integer); + + /// NOTE: memcpy will suffice, but unfortunately, this function is constexpr. + + using UnsignedT = std::make_unsigned_t; + + UnsignedT res {}; + for (unsigned i = 0; + i < _impl::item_count && i < (sizeof(T) + sizeof(base_type) - 1) / sizeof(base_type); ++i) + res += UnsignedT(items[_impl::little(i)]) + << (sizeof(base_type) * 8 * + i); // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) + + return res; +} + +template +constexpr integer::operator long double() const noexcept { + if (_impl::operator_eq(*this, 0)) return 0; + + integer tmp = *this; + if (_impl::is_negative(*this)) tmp = -tmp; + + long double res = 0; + for (unsigned i = 0; i < _impl::item_count; ++i) { + long double t = res; + res *= static_cast(std::numeric_limits::max()); + res += t; + res += tmp.items[_impl::big(i)]; + } + + if (_impl::is_negative(*this)) res = -res; + + return res; +} + +template +constexpr integer::operator double() const noexcept { + return static_cast(static_cast(*this)); +} + +template +constexpr integer::operator float() const noexcept { + return static_cast(static_cast(*this)); +} + +// Unary operators +template +constexpr integer operator~(const integer& lhs) noexcept { + return integer::_impl::operator_unary_tilda(lhs); +} + +template +constexpr integer operator-(const integer& lhs) noexcept( + std::is_same_v) { + return integer::_impl::operator_unary_minus(lhs); +} + +template +constexpr integer operator+(const integer& lhs) noexcept( + std::is_same_v) { + return lhs; +} + +#define CT(x) \ + std::common_type_t, std::decay_t> { \ + x \ + } + +// Binary operators +template +std::common_type_t, integer> constexpr operator*( + const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_star( + lhs, rhs); +} + +template +std::common_type_t constexpr operator*(const Arithmetic& lhs, + const Arithmetic2& rhs) { + return CT(lhs) * CT(rhs); 
+} + +template +std::common_type_t, integer> constexpr operator/( + const integer& lhs, const integer& rhs) { + return std::common_type_t, + integer>::_impl::operator_slash(lhs, rhs); +} +template +std::common_type_t constexpr operator/(const Arithmetic& lhs, + const Arithmetic2& rhs) { + return CT(lhs) / CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator+( + const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_plus( + lhs, rhs); +} +template +std::common_type_t constexpr operator+(const Arithmetic& lhs, + const Arithmetic2& rhs) { + return CT(lhs) + CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator-( + const integer& lhs, const integer& rhs) { + return std::common_type_t, + integer>::_impl::operator_minus(lhs, rhs); +} +template +std::common_type_t constexpr operator-(const Arithmetic& lhs, + const Arithmetic2& rhs) { + return CT(lhs) - CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator%( + const integer& lhs, const integer& rhs) { + return std::common_type_t, + integer>::_impl::operator_percent(lhs, rhs); +} +template +std::common_type_t constexpr operator%(const Integral& lhs, + const Integral2& rhs) { + return CT(lhs) % CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator&( + const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_amp( + lhs, rhs); +} +template +std::common_type_t constexpr operator&(const Integral& lhs, + const Integral2& rhs) { + return CT(lhs) & CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator|( + const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_pipe( + lhs, rhs); +} +template +std::common_type_t constexpr operator|(const Integral& lhs, + const Integral2& rhs) { + return CT(lhs) | CT(rhs); +} + +template +std::common_type_t, integer> constexpr operator^( + const integer& lhs, const integer& rhs) { + 
return std::common_type_t, + integer>::_impl::operator_circumflex(lhs, rhs); +} +template +std::common_type_t constexpr operator^(const Integral& lhs, + const Integral2& rhs) { + return CT(lhs) ^ CT(rhs); +} + +template +constexpr integer operator<<(const integer& lhs, int n) noexcept { + if (static_cast(n) >= Bits) return integer(0); + if (n <= 0) return lhs; + return integer::_impl::shift_left(lhs, n); +} +template +constexpr integer operator>>(const integer& lhs, int n) noexcept { + if (static_cast(n) >= Bits) return integer(0); + if (n <= 0) return lhs; + return integer::_impl::shift_right(lhs, n); +} + +template +constexpr bool operator<(const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_less( + lhs, rhs); +} +template +constexpr bool operator<(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) < CT(rhs); +} + +template +constexpr bool operator>(const integer& lhs, const integer& rhs) { + return std::common_type_t, + integer>::_impl::operator_greater(lhs, rhs); +} +template +constexpr bool operator>(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) > CT(rhs); +} + +template +constexpr bool operator<=(const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_less( + lhs, rhs) || + std::common_type_t, integer>::_impl::operator_eq( + lhs, rhs); +} +template +constexpr bool operator<=(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) <= CT(rhs); +} + +template +constexpr bool operator>=(const integer& lhs, const integer& rhs) { + return std::common_type_t, + integer>::_impl::operator_greater(lhs, rhs) || + std::common_type_t, integer>::_impl::operator_eq( + lhs, rhs); +} +template +constexpr bool operator>=(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) >= CT(rhs); +} + +template +constexpr bool operator==(const integer& lhs, const integer& rhs) { + return std::common_type_t, integer>::_impl::operator_eq( + lhs, 
rhs); +} +template +constexpr bool operator==(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) == CT(rhs); +} + +template +constexpr bool operator!=(const integer& lhs, const integer& rhs) { + return !std::common_type_t, integer>::_impl::operator_eq( + lhs, rhs); +} +template +constexpr bool operator!=(const Arithmetic& lhs, const Arithmetic2& rhs) { + return CT(lhs) != CT(rhs); +} + +template +constexpr auto operator<=>(const integer& lhs, const integer& rhs) { + return std::strong_ordering::equivalent; +} +template +constexpr auto operator<=>(const Arithmetic& lhs, const Arithmetic2& rhs) { + return std::strong_ordering::equivalent; +} + +#undef CT + +} // namespace wide + +namespace std { + +template +struct hash> { + std::size_t operator()(const wide::integer& lhs) const { + static_assert(Bits % (sizeof(size_t) * 8) == 0); + + const auto* ptr = reinterpret_cast(lhs.items); + unsigned count = Bits / (sizeof(size_t) * 8); + + size_t res = 0; + for (unsigned i = 0; i < count; ++i) res ^= ptr[i]; + return res; + } +}; + +} // namespace std + +// NOLINTEND(*) diff --git a/be/src/vec/core/wide_integer_to_string.h b/be/src/vec/core/wide_integer_to_string.h new file mode 100644 index 000000000000000..b08eca7b33a3fee --- /dev/null +++ b/be/src/vec/core/wide_integer_to_string.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/base/base/wide_integer_to_string.h +// and modified by Doris +#pragma once + +#include + +#include +#include + +#include "wide_integer.h" + +namespace wide { + +template +inline std::string to_string(const integer& n) { + std::string res; + if (integer::_impl::operator_eq(n, 0U)) return "0"; + + integer t; + bool is_neg = integer::_impl::is_negative(n); + if (is_neg) + t = integer::_impl::operator_unary_minus(n); + else + t = n; + + while (!integer::_impl::operator_eq(t, 0U)) { + res.insert(res.begin(), + '0' + char(integer::_impl::operator_percent(t, 10U))); + t = integer::_impl::operator_slash(t, 10U); + } + + if (is_neg) res.insert(res.begin(), '-'); + return res; +} + +} // namespace wide + +template +std::ostream& operator<<(std::ostream& out, const wide::integer& value) { + return out << to_string(value); +} + +/// See https://fmt.dev/latest/api.html#formatting-user-defined-types +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) { + const auto* it = ctx.begin(); + const auto* end = ctx.end(); + + /// Only support {}. 
+ if (it != end && *it != '}') throw format_error("invalid format"); + + return it; + } + + template + auto format(const wide::integer& value, FormatContext& ctx) { + return fmt::format_to(ctx.out(), "{}", to_string(value)); + } +}; diff --git a/be/src/vec/data_types/convert_field_to_type.cpp b/be/src/vec/data_types/convert_field_to_type.cpp index ba49257898051d2..b5c4263181eb49d 100644 --- a/be/src/vec/data_types/convert_field_to_type.cpp +++ b/be/src/vec/data_types/convert_field_to_type.cpp @@ -82,6 +82,9 @@ class FieldVisitorToStringSimple : public StaticVisitor { [[noreturn]] String operator()(const DecimalField& x) const { LOG(FATAL) << "not implemeted"; } + [[noreturn]] String operator()(const DecimalField& x) const { + LOG(FATAL) << "not implemeted"; + } }; namespace { diff --git a/be/src/vec/data_types/data_type.cpp b/be/src/vec/data_types/data_type.cpp index 8b7a094dcf5b401..48d37b38c397a26 100644 --- a/be/src/vec/data_types/data_type.cpp +++ b/be/src/vec/data_types/data_type.cpp @@ -139,6 +139,8 @@ PGenericType_TypeId IDataType::get_pdata_type(const IDataType* data_type) { return PGenericType::DECIMAL128; case TypeIndex::Decimal128I: return PGenericType::DECIMAL128I; + case TypeIndex::Decimal256: + return PGenericType::DECIMAL256; case TypeIndex::String: return PGenericType::STRING; case TypeIndex::Date: diff --git a/be/src/vec/data_types/data_type.h b/be/src/vec/data_types/data_type.h index 2aee6fdb1e47ab9..fdfc0fde82a92f2 100644 --- a/be/src/vec/data_types/data_type.h +++ b/be/src/vec/data_types/data_type.h @@ -286,8 +286,10 @@ struct WhichDataType { bool is_decimal64() const { return idx == TypeIndex::Decimal64; } bool is_decimal128() const { return idx == TypeIndex::Decimal128; } bool is_decimal128i() const { return idx == TypeIndex::Decimal128I; } + bool is_decimal256() const { return idx == TypeIndex::Decimal256; } bool is_decimal() const { - return is_decimal32() || is_decimal64() || is_decimal128() || is_decimal128i(); + return is_decimal32() || 
is_decimal64() || is_decimal128() || is_decimal128i() || + is_decimal256(); } bool is_float32() const { return idx == TypeIndex::Float32; } diff --git a/be/src/vec/data_types/data_type_decimal.cpp b/be/src/vec/data_types/data_type_decimal.cpp index 2f71ee736c9e187..f69d169179be897 100644 --- a/be/src/vec/data_types/data_type_decimal.cpp +++ b/be/src/vec/data_types/data_type_decimal.cpp @@ -35,6 +35,7 @@ #include "vec/common/int_exp.h" #include "vec/common/string_buffer.hpp" #include "vec/common/typeid_cast.h" +#include "vec/core/types.h" #include "vec/io/io_helper.h" #include "vec/io/reader_buffer.h" @@ -166,10 +167,10 @@ bool DataTypeDecimal::parse_from_string(const std::string& str, T* res) const DataTypePtr create_decimal(UInt64 precision_value, UInt64 scale_value, bool use_v2) { if (precision_value < min_decimal_precision() || - precision_value > max_decimal_precision()) { + precision_value > max_decimal_precision()) { throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, "Wrong precision {}, min: {}, max: {}", precision_value, - min_decimal_precision(), max_decimal_precision()); + min_decimal_precision(), max_decimal_precision()); } if (static_cast(scale_value) > precision_value) { @@ -187,8 +188,10 @@ DataTypePtr create_decimal(UInt64 precision_value, UInt64 scale_value, bool use_ return std::make_shared>(precision_value, scale_value); } else if (precision_value <= max_decimal_precision()) { return std::make_shared>(precision_value, scale_value); + } else if (precision_value <= max_decimal_precision()) { + return std::make_shared>(precision_value, scale_value); } - return std::make_shared>(precision_value, scale_value); + return std::make_shared>(precision_value, scale_value); } template <> @@ -211,10 +214,16 @@ Decimal128I DataTypeDecimal::get_scale_multiplier(UInt32 scale) { return common::exp10_i128(scale); } +template <> +Decimal256 DataTypeDecimal::get_scale_multiplier(UInt32 scale) { + return Decimal256(common::exp10_i256(scale)); +} + /// 
Explicit template instantiations. template class DataTypeDecimal; template class DataTypeDecimal; template class DataTypeDecimal; template class DataTypeDecimal; +template class DataTypeDecimal; } // namespace doris::vectorized diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index c704c90365e0090..32c295681adc9ed 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -34,6 +34,7 @@ // IWYU pragma: no_include #include "common/compiler_util.h" // IWYU pragma: keep +#include "common/consts.h" #include "common/logging.h" #include "common/status.h" #include "olap/olap_common.h" @@ -74,19 +75,23 @@ constexpr size_t max_decimal_precision() { } template <> constexpr size_t max_decimal_precision() { - return 9; + return BeConsts::MAX_DECIMAL32_PRECISION; } template <> constexpr size_t max_decimal_precision() { - return 18; + return BeConsts::MAX_DECIMAL64_PRECISION; } template <> constexpr size_t max_decimal_precision() { - return 38; + return BeConsts::MAX_DECIMAL128_PRECISION; } template <> constexpr size_t max_decimal_precision() { - return 38; + return BeConsts::MAX_DECIMAL128_PRECISION; +} +template <> +constexpr size_t max_decimal_precision() { + return BeConsts::MAX_DECIMAL256_PRECISION; } DataTypePtr create_decimal(UInt64 precision, UInt64 scale, bool use_v2); @@ -155,6 +160,9 @@ class DataTypeDecimal final : public IDataType { if constexpr (std::is_same_v, TypeId>) { return TYPE_DECIMAL128I; } + if constexpr (std::is_same_v, TypeId>) { + return TYPE_DECIMAL256; + } return TYPE_DECIMALV2; } @@ -168,6 +176,9 @@ class DataTypeDecimal final : public IDataType { if constexpr (std::is_same_v, TypeId>) { return TPrimitiveType::DECIMAL128I; } + if constexpr (std::is_same_v, TypeId>) { + return TPrimitiveType::DECIMAL256; + } LOG(FATAL) << "__builtin_unreachable"; __builtin_unreachable(); } @@ -254,7 +265,7 @@ class DataTypeDecimal final : public IDataType { return x % 
get_scale_multiplier(); } - T max_whole_value() const { return get_scale_multiplier(max_precision() - scale) - 1; } + T max_whole_value() const { return get_scale_multiplier(max_precision() - scale) - T(1); } bool can_store_whole(T x) const { T max = max_whole_value(); @@ -326,12 +337,12 @@ DataTypePtr decimal_result_type(const DataTypeDecimal& tx, const DataTypeDeci scale + 1; if (is_multiply) { scale = tx.get_scale() + ty.get_scale(); - precision = std::min(multiply_precision, max_decimal_precision()); + precision = std::min(multiply_precision, max_decimal_precision()); } else if (is_divide) { scale = tx.get_scale(); - precision = std::min(divide_precision, max_decimal_precision()); + precision = std::min(divide_precision, max_decimal_precision()); } else if (is_plus_minus) { - precision = std::min(plus_minus_precision, max_decimal_precision()); + precision = std::min(plus_minus_precision, max_decimal_precision()); } return create_decimal(precision, scale, false); } @@ -355,6 +366,9 @@ inline UInt32 get_decimal_scale(const IDataType& data_type, UInt32 default_value if (auto* decimal_type = check_decimal(data_type)) { return decimal_type->get_scale(); } + if (auto* decimal_type = check_decimal(data_type)) { + return decimal_type->get_scale(); + } return default_value; } @@ -370,6 +384,8 @@ template <> inline constexpr bool IsDataTypeDecimal> = true; template <> inline constexpr bool IsDataTypeDecimal> = true; +template <> +inline constexpr bool IsDataTypeDecimal> = true; template constexpr bool IsDataTypeDecimalV2 = false; @@ -381,6 +397,11 @@ constexpr bool IsDataTypeDecimal128I = false; template <> inline constexpr bool IsDataTypeDecimal128I> = true; +template +constexpr bool IsDataTypeDecimal256 = false; +template <> +inline constexpr bool IsDataTypeDecimal256> = true; + template constexpr bool IsDataTypeDecimalOrNumber = IsDataTypeDecimal || IsDataTypeNumber; @@ -414,7 +435,8 @@ ToDataType::FieldType convert_decimals(const typename FromDataType::FieldType& v } 
} else { converted_value = - value / DataTypeDecimal::get_scale_multiplier(scale_from - scale_to); + static_cast(value) / + DataTypeDecimal::get_scale_multiplier(scale_from - scale_to); } if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType)) { @@ -457,8 +479,9 @@ void convert_decimal_cols( DataTypeDecimal::get_scale_multiplier(scale_to - scale_from); MaxNativeType res; for (size_t i = 0; i < sz; i++) { - if (std::is_same_v) { - if (common::mul_overflow(static_cast(vec_from[i]), multiplier, + if constexpr (std::is_same_v || + std::is_same_v) { + if (common::mul_overflow(static_cast(vec_from[i].value), multiplier, res)) { if (overflow_flag) { overflow_flag[i] = 1; @@ -466,10 +489,10 @@ void convert_decimal_cols( vec_to[i] = res < 0 ? type_limit::min() : type_limit::max(); } else { - vec_to[i] = res; + vec_to[i] = ToFieldType(res); } } else { - vec_to[i] = vec_from[i] * multiplier; + vec_to[i] = ToFieldType(vec_from[i].value * multiplier); } } } else { @@ -477,9 +500,9 @@ void convert_decimal_cols( DataTypeDecimal::get_scale_multiplier(scale_from - scale_to); for (size_t i = 0; i < sz; i++) { if (vec_from[i] >= FromFieldType(0)) { - vec_to[i] = (vec_from[i] + multiplier / 2) / multiplier; + vec_to[i] = ToFieldType((vec_from[i].value + multiplier / 2) / multiplier); } else { - vec_to[i] = (vec_from[i] - multiplier / 2) / multiplier; + vec_to[i] = ToFieldType((vec_from[i].value - multiplier / 2) / multiplier); } } } @@ -512,7 +535,8 @@ ToDataType::FieldType convert_from_decimal(const typename FromDataType::FieldTyp if constexpr (IsDecimalV2) { return binary_cast(value); } else { - return static_cast(value) / FromDataType::get_scale_multiplier(scale); + return static_cast(value.value) / + FromDataType::get_scale_multiplier(scale).value; } } else { FromFieldType converted_value = @@ -562,7 +586,7 @@ ToDataType::FieldType convert_to_decimal(const typename FromDataType::FieldType& VLOG_DEBUG << "Decimal convert overflow. 
Float is out of Decimal range"; return type_limit::max(); } - return out; + return typename ToDataType::FieldType(out); } else { if constexpr (std::is_same_v) { if (value > static_cast(std::numeric_limits::max())) { @@ -576,14 +600,16 @@ ToDataType::FieldType convert_to_decimal(const typename FromDataType::FieldType& template requires IsDecimalNumber typename T::NativeType max_decimal_value(UInt32 precision) { - return type_limit::max() / DataTypeDecimal::get_scale_multiplier( - (UInt32)(max_decimal_precision() - precision)); + return type_limit::max().value / DataTypeDecimal::get_scale_multiplier( + (UInt32)(max_decimal_precision() - precision)) + .value; } template requires IsDecimalNumber typename T::NativeType min_decimal_value(UInt32 precision) { - return type_limit::min() / DataTypeDecimal::get_scale_multiplier( - (UInt32)(max_decimal_precision() - precision)); + return type_limit::min().value / DataTypeDecimal::get_scale_multiplier( + (UInt32)(max_decimal_precision() - precision)) + .value; } } // namespace doris::vectorized diff --git a/be/src/vec/data_types/data_type_factory.cpp b/be/src/vec/data_types/data_type_factory.cpp index 4ab836141b80aaf..b5700cb7f0bc1b1 100644 --- a/be/src/vec/data_types/data_type_factory.cpp +++ b/be/src/vec/data_types/data_type_factory.cpp @@ -187,6 +187,7 @@ DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, bo case TYPE_DECIMAL32: case TYPE_DECIMAL64: case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: nested = vectorized::create_decimal(col_desc.precision, col_desc.scale, false); break; // Just Mock A NULL Type in Vec Exec Engine @@ -302,6 +303,10 @@ DataTypePtr DataTypeFactory::create_data_type(const TypeIndex& type_index, bool nested = std::make_shared>(BeConsts::MAX_DECIMAL128_PRECISION, 0); break; + case TypeIndex::Decimal256: + nested = std::make_shared>(BeConsts::MAX_DECIMAL256_PRECISION, + 0); + break; case TypeIndex::JSONB: nested = std::make_shared(); break; @@ -394,6 +399,7 @@ DataTypePtr 
DataTypeFactory::_create_primitive_data_type(const FieldType& type, case FieldType::OLAP_FIELD_TYPE_DECIMAL32: case FieldType::OLAP_FIELD_TYPE_DECIMAL64: case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: result = vectorized::create_decimal(precision, scale, false); break; default: @@ -479,6 +485,10 @@ DataTypePtr DataTypeFactory::create_data_type(const PColumnMeta& pcolumn) { nested = std::make_shared>(pcolumn.decimal_param().precision(), pcolumn.decimal_param().scale()); break; + case PGenericType::DECIMAL256: + nested = std::make_shared>(pcolumn.decimal_param().precision(), + pcolumn.decimal_param().scale()); + break; case PGenericType::BITMAP: nested = std::make_shared(); break; diff --git a/be/src/vec/data_types/get_least_supertype.cpp b/be/src/vec/data_types/get_least_supertype.cpp index be9dd5c05c4f1c3..9db1271450d667e 100644 --- a/be/src/vec/data_types/get_least_supertype.cpp +++ b/be/src/vec/data_types/get_least_supertype.cpp @@ -358,10 +358,12 @@ void get_least_supertype(const DataTypes& types, DataTypePtr* type, bool compati UInt32 have_decimal64 = type_ids.count(TypeIndex::Decimal64); UInt32 have_decimal128 = type_ids.count(TypeIndex::Decimal128); UInt32 have_decimal128i = type_ids.count(TypeIndex::Decimal128I); + UInt32 have_decimal256 = type_ids.count(TypeIndex::Decimal256); - if (have_decimal32 || have_decimal64 || have_decimal128 || have_decimal128i) { - UInt32 num_supported = - have_decimal32 + have_decimal64 + have_decimal128 + have_decimal128i; + if (have_decimal32 || have_decimal64 || have_decimal128 || have_decimal128i || + have_decimal256) { + UInt32 num_supported = have_decimal32 + have_decimal64 + have_decimal128 + + have_decimal128i + have_decimal256; std::vector int_ids = { TypeIndex::Int8, TypeIndex::UInt8, TypeIndex::Int16, TypeIndex::UInt16, @@ -401,7 +403,7 @@ void get_least_supertype(const DataTypes& types, DataTypePtr* type, bool compati min_precision = DataTypeDecimal::max_precision(); } - if 
(min_precision > DataTypeDecimal::max_precision()) { + if (min_precision > DataTypeDecimal::max_precision()) { LOG(INFO) << fmt::format("{} because the least supertype is Decimal({},{})", get_exception_message_prefix(types), min_precision, max_scale); @@ -412,6 +414,11 @@ void get_least_supertype(const DataTypes& types, DataTypePtr* type, bool compati doris::ErrorCode::INVALID_ARGUMENT); } + if (have_decimal256 || min_precision > DataTypeDecimal::max_precision()) { + *type = std::make_shared>( + DataTypeDecimal::max_precision(), max_scale); + return; + } if (have_decimal128 || min_precision > DataTypeDecimal::max_precision()) { *type = std::make_shared>( DataTypeDecimal::max_precision(), max_scale); diff --git a/be/src/vec/data_types/number_traits.h b/be/src/vec/data_types/number_traits.h index 8b87e55d93c6df4..2d05d65681fbd13 100644 --- a/be/src/vec/data_types/number_traits.h +++ b/be/src/vec/data_types/number_traits.h @@ -26,6 +26,7 @@ #include "vec/columns/column_vector.h" #include "vec/common/uint128.h" #include "vec/core/types.h" +#include "vec/core/wide_integer.h" namespace doris::vectorized { @@ -76,6 +77,10 @@ struct Construct { using Type = Int128; }; template <> +struct Construct { + using Type = wide::Int256; +}; +template <> struct Construct { using Type = Float32; }; @@ -112,6 +117,10 @@ struct Construct { using Type = Int128; }; template <> +struct Construct { + using Type = wide::Int256; +}; +template <> struct Construct { using Type = Float32; }; diff --git a/be/src/vec/data_types/serde/data_type_decimal_serde.cpp b/be/src/vec/data_types/serde/data_type_decimal_serde.cpp index e6628d8c5b3dd6c..dae309119b69715 100644 --- a/be/src/vec/data_types/serde/data_type_decimal_serde.cpp +++ b/be/src/vec/data_types/serde/data_type_decimal_serde.cpp @@ -111,6 +111,7 @@ void DataTypeDecimalSerDe::write_column_to_arrow(const IColumn& column, const checkArrowStatus(builder.Append(value), column.get_name(), array_builder->type()->name()); } + // TODO: decimal256 } 
else if constexpr (std::is_same_v) { std::shared_ptr s_decimal_ptr = std::make_shared(38, col.get_scale()); @@ -277,5 +278,6 @@ template class DataTypeDecimalSerDe; template class DataTypeDecimalSerDe; template class DataTypeDecimalSerDe; template class DataTypeDecimalSerDe; +template class DataTypeDecimalSerDe; } // namespace vectorized } // namespace doris diff --git a/be/src/vec/data_types/serde/data_type_decimal_serde.h b/be/src/vec/data_types/serde/data_type_decimal_serde.h index 5085d4036149d02..4843a6b90e7187a 100644 --- a/be/src/vec/data_types/serde/data_type_decimal_serde.h +++ b/be/src/vec/data_types/serde/data_type_decimal_serde.h @@ -28,6 +28,7 @@ #include "common/status.h" #include "data_type_serde.h" #include "olap/olap_common.h" +#include "runtime/define_primitive_type.h" #include "util/jsonb_document.h" #include "util/jsonb_writer.h" #include "vec/columns/column.h" @@ -60,6 +61,9 @@ class DataTypeDecimalSerDe : public DataTypeSerDe { if constexpr (std::is_same_v, TypeId>) { return TYPE_DECIMALV2; } + if constexpr (std::is_same_v, TypeId>) { + return TYPE_DECIMAL256; + } LOG(FATAL) << "__builtin_unreachable"; __builtin_unreachable(); } @@ -128,6 +132,8 @@ Status DataTypeDecimalSerDe::write_column_to_pb(const IColumn& column, PValue ptype->set_id(PGenericType::DECIMAL128); } else if constexpr (std::is_same_v) { ptype->set_id(PGenericType::DECIMAL128I); + } else if constexpr (std::is_same_v) { + ptype->set_id(PGenericType::DECIMAL256); } else if constexpr (std::is_same_v>) { ptype->set_id(PGenericType::INT32); } else if constexpr (std::is_same_v>) { @@ -143,10 +149,12 @@ Status DataTypeDecimalSerDe::write_column_to_pb(const IColumn& column, PValue return Status::OK(); } +// TODO: decimal256 template Status DataTypeDecimalSerDe::read_column_from_pb(IColumn& column, const PValues& arg) const { if constexpr (std::is_same_v> || std::is_same_v || - std::is_same_v> || std::is_same_v>) { + std::is_same_v || std::is_same_v> || + std::is_same_v>) { 
column.resize(arg.bytes_value_size()); auto& data = reinterpret_cast&>(column).get_data(); for (int i = 0; i < arg.bytes_value_size(); ++i) { @@ -164,6 +172,7 @@ void DataTypeDecimalSerDe::write_one_cell_to_jsonb(const IColumn& column, Jso int row_num) const { StringRef data_ref = column.get_data_at(row_num); result.writeKey(col_id); + // TODO: decimal256 if constexpr (std::is_same_v>) { Decimal128::NativeType val = *reinterpret_cast(data_ref.data); @@ -188,6 +197,7 @@ template void DataTypeDecimalSerDe::read_one_cell_from_jsonb(IColumn& column, const JsonbValue* arg) const { auto& col = reinterpret_cast&>(column); + // TODO: decimal256 if constexpr (std::is_same_v>) { col.insert_value(static_cast(arg)->val()); } else if constexpr (std::is_same_v) { diff --git a/be/src/vec/exec/format/parquet/byte_array_dict_decoder.cpp b/be/src/vec/exec/format/parquet/byte_array_dict_decoder.cpp index 1e09890a9807fd8..6f5f36a33a972d1 100644 --- a/be/src/vec/exec/format/parquet/byte_array_dict_decoder.cpp +++ b/be/src/vec/exec/format/parquet/byte_array_dict_decoder.cpp @@ -169,6 +169,7 @@ Status ByteArrayDictDecoder::_decode_values(MutableColumnPtr& doris_column, Data return _decode_binary_decimal(doris_column, data_type, select_vector); case TypeIndex::Decimal128I: return _decode_binary_decimal(doris_column, data_type, select_vector); + // TODO: decimal256 default: break; } diff --git a/be/src/vec/exec/format/parquet/byte_array_plain_decoder.cpp b/be/src/vec/exec/format/parquet/byte_array_plain_decoder.cpp index 9a032b540b3757f..e91f9f1db94ce2e 100644 --- a/be/src/vec/exec/format/parquet/byte_array_plain_decoder.cpp +++ b/be/src/vec/exec/format/parquet/byte_array_plain_decoder.cpp @@ -118,6 +118,7 @@ Status ByteArrayPlainDecoder::_decode_values(MutableColumnPtr& doris_column, Dat return _decode_binary_decimal(doris_column, data_type, select_vector); case TypeIndex::Decimal128I: return _decode_binary_decimal(doris_column, data_type, select_vector); + // TODO: decimal256 default: 
break; } diff --git a/be/src/vec/exec/format/parquet/fix_length_dict_decoder.hpp b/be/src/vec/exec/format/parquet/fix_length_dict_decoder.hpp index a30c2dff3d165e0..35880cfcdd3080b 100644 --- a/be/src/vec/exec/format/parquet/fix_length_dict_decoder.hpp +++ b/be/src/vec/exec/format/parquet/fix_length_dict_decoder.hpp @@ -150,6 +150,7 @@ class FixLengthDictDecoder final : public BaseDictDecoder { select_vector); } break; + // TODO: decimal256 case TypeIndex::String: [[fallthrough]]; case TypeIndex::FixedString: @@ -512,6 +513,7 @@ class FixLengthDictDecoder final : public BaseDictDecoder { select_vector); } break; + // TODO: decimal256 case TypeIndex::String: [[fallthrough]]; case TypeIndex::FixedString: diff --git a/be/src/vec/exec/format/parquet/fix_length_plain_decoder.cpp b/be/src/vec/exec/format/parquet/fix_length_plain_decoder.cpp index af464c155459400..8e6f6ebb67ff049 100644 --- a/be/src/vec/exec/format/parquet/fix_length_plain_decoder.cpp +++ b/be/src/vec/exec/format/parquet/fix_length_plain_decoder.cpp @@ -173,6 +173,7 @@ Status FixLengthPlainDecoder::_decode_values(MutableColumnPtr& doris_column, Dat select_vector); } break; + // TODO: decimal256 case TypeIndex::String: [[fallthrough]]; case TypeIndex::FixedString: diff --git a/be/src/vec/exec/jni_connector.cpp b/be/src/vec/exec/jni_connector.cpp index 88b860f64b117c0..00fa83a51898392 100644 --- a/be/src/vec/exec/jni_connector.cpp +++ b/be/src/vec/exec/jni_connector.cpp @@ -283,6 +283,7 @@ Status JniConnector::_fill_column(ColumnPtr& doris_column, DataTypePtr& data_typ data_column, reinterpret_cast(_next_meta_as_ptr()), num_rows); FOR_LOGICAL_NUMERIC_TYPES(DISPATCH) #undef DISPATCH + // TODO: decimal256 case TypeIndex::Decimal128: [[fallthrough]]; case TypeIndex::Decimal128I: diff --git a/be/src/vec/exec/scan/vscan_node.cpp b/be/src/vec/exec/scan/vscan_node.cpp index 87a431fd27bd8f1..7be9f6469e55034 100644 --- a/be/src/vec/exec/scan/vscan_node.cpp +++ b/be/src/vec/exec/scan/vscan_node.cpp @@ -401,6 +401,7 
@@ Status VScanNode::_normalize_conjuncts() { M(DECIMAL32) \ M(DECIMAL64) \ M(DECIMAL128I) \ + M(DECIMAL256) \ M(DECIMALV2) \ M(BOOLEAN) APPLY_FOR_PRIMITIVE_TYPE(M) @@ -1219,7 +1220,8 @@ Status VScanNode::_change_value_range(ColumnValueRange& temp_rang (PrimitiveType == TYPE_SMALLINT) || (PrimitiveType == TYPE_INT) || (PrimitiveType == TYPE_BIGINT) || (PrimitiveType == TYPE_LARGEINT) || (PrimitiveType == TYPE_DECIMAL32) || (PrimitiveType == TYPE_DECIMAL64) || - (PrimitiveType == TYPE_DECIMAL128I) || (PrimitiveType == TYPE_STRING) || + (PrimitiveType == TYPE_DECIMAL128I) || + (PrimitiveType == TYPE_DECIMAL256) || (PrimitiveType == TYPE_STRING) || (PrimitiveType == TYPE_BOOLEAN) || (PrimitiveType == TYPE_DATEV2)) { if constexpr (IsFixed) { func(temp_range, diff --git a/be/src/vec/exec/vjdbc_connector.cpp b/be/src/vec/exec/vjdbc_connector.cpp index bb50e21163cb0ee..08e57f54325ae8f 100644 --- a/be/src/vec/exec/vjdbc_connector.cpp +++ b/be/src/vec/exec/vjdbc_connector.cpp @@ -335,7 +335,8 @@ Status JdbcConnector::_check_type(SlotDescriptor* slot_desc, const std::string& case TYPE_DECIMALV2: case TYPE_DECIMAL32: case TYPE_DECIMAL64: - case TYPE_DECIMAL128I: { + case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: { if (type_str != "java.math.BigDecimal") { return Status::InternalError(error_msg); } diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 1f80d8c8393b810..27e659eeb0cc0d1 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -202,7 +202,7 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, } else { _function = AggregateFunctionSimpleFactory::instance().get( _fn.name.function_name, argument_types, _data_type->is_nullable(), - state->be_exec_version()); + state->be_exec_version(), state->enable_decima256()); } if (_function == nullptr) { return Status::InternalError("Agg Function {} is not implemented", _fn.signature); diff --git 
a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp index b483642f9cb536b..109ef5e77d324b9 100644 --- a/be/src/vec/exprs/vexpr.cpp +++ b/be/src/vec/exprs/vexpr.cpp @@ -125,6 +125,11 @@ TExprNode create_texpr_node_from(const void* data, const PrimitiveType& type, in create_texpr_literal_node(data, &node, precision, scale)); break; } + case TYPE_DECIMAL256: { + static_cast( + create_texpr_literal_node(data, &node, precision, scale)); + break; + } case TYPE_CHAR: { static_cast(create_texpr_literal_node(data, &node)); break; diff --git a/be/src/vec/exprs/vexpr.h b/be/src/vec/exprs/vexpr.h index 4bd891790687f1c..84b648c2fb7fdd8 100644 --- a/be/src/vec/exprs/vexpr.h +++ b/be/src/vec/exprs/vexpr.h @@ -39,6 +39,7 @@ #include "vec/columns/column.h" #include "vec/core/block.h" #include "vec/core/column_with_type_and_name.h" +#include "vec/core/wide_integer.h" #include "vec/data_types/data_type.h" #include "vec/exprs/vexpr_fwd.h" #include "vec/functions/function.h" @@ -365,6 +366,13 @@ Status create_texpr_literal_node(const void* data, TExprNode* node, int precisio decimal_literal.__set_value(origin_value->to_string(scale)); (*node).__set_decimal_literal(decimal_literal); (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMAL128I, precision, scale)); + } else if constexpr (T == TYPE_DECIMAL256) { + auto origin_value = reinterpret_cast*>(data); + (*node).__set_node_type(TExprNodeType::DECIMAL_LITERAL); + TDecimalLiteral decimal_literal; + decimal_literal.__set_value(origin_value->to_string(scale)); + (*node).__set_decimal_literal(decimal_literal); + (*node).__set_type(create_type_desc(PrimitiveType::TYPE_DECIMAL256, precision, scale)); } else if constexpr (T == TYPE_FLOAT) { auto origin_value = reinterpret_cast(data); (*node).__set_node_type(TExprNodeType::FLOAT_LITERAL); diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp b/be/src/vec/functions/array/function_array_aggregation.cpp index 102acb9d3941da2..be8373fdc6c27ab 100644 --- 
a/be/src/vec/functions/array/function_array_aggregation.cpp +++ b/be/src/vec/functions/array/function_array_aggregation.cpp @@ -169,6 +169,7 @@ struct ArrayAggregateImpl { execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || + // execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || diff --git a/be/src/vec/functions/array/function_array_apply.cpp b/be/src/vec/functions/array/function_array_apply.cpp index d05ba904f3085fc..0e9076e65e50b41 100644 --- a/be/src/vec/functions/array/function_array_apply.cpp +++ b/be/src/vec/functions/array/function_array_apply.cpp @@ -210,6 +210,8 @@ class FunctionArrayApply : public IFunction { *dst = _apply_internal(src_column, src_offsets, cmp); \ } else if (which.is_decimal128i()) { \ *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal256()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ } else { \ LOG(FATAL) << "unsupported type " << nested_type->get_name(); \ } \ diff --git a/be/src/vec/functions/array/function_array_difference.h b/be/src/vec/functions/array/function_array_difference.h index 492d7cdcb772527..e19049148680a73 100644 --- a/be/src/vec/functions/array/function_array_difference.h +++ b/be/src/vec/functions/array/function_array_difference.h @@ -229,6 +229,9 @@ class FunctionArrayDifference : public IFunction { } else if (check_column(*nested_column)) { res = _execute_number_expanded(offsets, *nested_column, nested_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number_expanded(offsets, *nested_column, + nested_null_map); } else if (check_column(*nested_column)) { res = _execute_number_expanded(offsets, *nested_column, nested_null_map); diff --git a/be/src/vec/functions/array/function_array_distinct.h b/be/src/vec/functions/array/function_array_distinct.h index 
dff894d11511de4..7e5e0a73729299b 100644 --- a/be/src/vec/functions/array/function_array_distinct.h +++ b/be/src/vec/functions/array/function_array_distinct.h @@ -304,6 +304,9 @@ class FunctionArrayDistinct : public IFunction { } else if (which.is_decimal128i()) { res = _execute_number(src_column, src_offsets, dest_column, dest_offsets, src_null_map, dest_null_map); + } else if (which.is_decimal256()) { + res = _execute_number(src_column, src_offsets, dest_column, + dest_offsets, src_null_map, dest_null_map); } else if (which.is_decimal128()) { res = _execute_number(src_column, src_offsets, dest_column, dest_offsets, src_null_map, dest_null_map); diff --git a/be/src/vec/functions/array/function_array_element.h b/be/src/vec/functions/array/function_array_element.h index fcd54c6dfaabb7e..5ed50fe3d89cc1d 100644 --- a/be/src/vec/functions/array/function_array_element.h +++ b/be/src/vec/functions/array/function_array_element.h @@ -390,6 +390,9 @@ class FunctionArrayElement : public IFunction { } else if (check_column(*nested_column)) { res = _execute_number(offsets, *nested_column, src_null_map, *idx_col, nested_null_map, dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, *idx_col, + nested_null_map, dst_null_map); } else if (check_column(*nested_column)) { res = _execute_number(offsets, *nested_column, src_null_map, *idx_col, nested_null_map, dst_null_map); diff --git a/be/src/vec/functions/array/function_array_enumerate_uniq.cpp b/be/src/vec/functions/array/function_array_enumerate_uniq.cpp index ead65b0e1e2620e..f3bbf3c57a81651 100644 --- a/be/src/vec/functions/array/function_array_enumerate_uniq.cpp +++ b/be/src/vec/functions/array/function_array_enumerate_uniq.cpp @@ -198,6 +198,8 @@ class FunctionArrayEnumerateUniq : public IFunction { _execute_number(data_columns, *offsets, null_map, dst_values); } else if (which.is_decimal128i()) { _execute_number(data_columns, *offsets, null_map, 
dst_values); + } else if (which.is_decimal256()) { + _execute_number(data_columns, *offsets, null_map, dst_values); } else if (which.is_date_time_v2()) { _execute_number(data_columns, *offsets, null_map, dst_values); } else if (which.is_decimal128()) { diff --git a/be/src/vec/functions/array/function_array_index.h b/be/src/vec/functions/array/function_array_index.h index 27aaa24cdad9313..7644690b7427af0 100644 --- a/be/src/vec/functions/array/function_array_index.h +++ b/be/src/vec/functions/array/function_array_index.h @@ -351,6 +351,10 @@ class FunctionArrayIndex : public IFunction { return_column = _execute_number_expanded( offsets, nested_null_map, *nested_column, *right_column, right_nested_null_map, array_null_map); + } else if (check_column(*nested_column)) { + return_column = _execute_number_expanded( + offsets, nested_null_map, *nested_column, *right_column, + right_nested_null_map, array_null_map); } else if (check_column(*nested_column)) { return_column = _execute_number_expanded( offsets, nested_null_map, *nested_column, *right_column, diff --git a/be/src/vec/functions/array/function_array_join.h b/be/src/vec/functions/array/function_array_join.h index d822c45a41d5f66..b9829649148a0bf 100644 --- a/be/src/vec/functions/array/function_array_join.h +++ b/be/src/vec/functions/array/function_array_join.h @@ -251,6 +251,9 @@ struct ArrayJoinImpl { res = _execute_number(src_column, src_offsets, src_null_map, sep_str, null_replace_str, nested_type, dest_column_ptr); + } else if (which.is_decimal256()) { + res = _execute_number(src_column, src_offsets, src_null_map, sep_str, + null_replace_str, nested_type, dest_column_ptr); } else if (which.is_decimal128()) { res = _execute_number(src_column, src_offsets, src_null_map, sep_str, null_replace_str, nested_type, dest_column_ptr); diff --git a/be/src/vec/functions/array/function_array_remove.h b/be/src/vec/functions/array/function_array_remove.h index eb33952a2dbfa7a..3eae6a25064feed 100644 --- 
a/be/src/vec/functions/array/function_array_remove.h +++ b/be/src/vec/functions/array/function_array_remove.h @@ -324,6 +324,9 @@ class FunctionArrayRemove : public IFunction { } else if (check_column(*nested_column)) { res = _execute_number_expanded(offsets, *nested_column, *right_column, nested_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number_expanded(offsets, *nested_column, + *right_column, nested_null_map); } else if (check_column(*nested_column)) { res = _execute_number_expanded(offsets, *nested_column, *right_column, nested_null_map); diff --git a/be/src/vec/functions/array/function_arrays_overlap.h b/be/src/vec/functions/array/function_arrays_overlap.h index 7af722e10a5cd01..bad3aa047d3f324 100644 --- a/be/src/vec/functions/array/function_arrays_overlap.h +++ b/be/src/vec/functions/array/function_arrays_overlap.h @@ -226,6 +226,10 @@ class FunctionArraysOverlap : public IFunction { ret = _execute_internal(left_exec_data, right_exec_data, dst_null_map_data, dst_nested_col->get_data().data()); + } else if (check_column(*left_exec_data.nested_col)) { + ret = _execute_internal(left_exec_data, right_exec_data, + dst_null_map_data, + dst_nested_col->get_data().data()); } else if (check_column(*left_exec_data.nested_col)) { ret = _execute_internal(left_exec_data, right_exec_data, dst_null_map_data, diff --git a/be/src/vec/functions/function.h b/be/src/vec/functions/function.h index aeea5d1df048e5c..5f481aeef399a7f 100644 --- a/be/src/vec/functions/function.h +++ b/be/src/vec/functions/function.h @@ -682,11 +682,12 @@ ColumnPtr wrap_in_nullable(const ColumnPtr& src, const Block& block, const Colum M(Float32, ColumnFloat32) \ M(Float64, ColumnFloat64) -#define DECIMAL_TYPE_TO_COLUMN_TYPE(M) \ - M(Decimal32, ColumnDecimal) \ - M(Decimal64, ColumnDecimal) \ - M(Decimal128, ColumnDecimal) \ - M(Decimal128I, ColumnDecimal) +#define DECIMAL_TYPE_TO_COLUMN_TYPE(M) \ + M(Decimal32, ColumnDecimal) \ + M(Decimal64, ColumnDecimal) \ + 
M(Decimal128, ColumnDecimal) \ + M(Decimal128I, ColumnDecimal) \ + M(Decimal256, ColumnDecimal) #define STRING_TYPE_TO_COLUMN_TYPE(M) \ M(String, ColumnString) \ diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 122d5b01e90117d..10fffd9db8b4a41 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -79,7 +79,7 @@ struct OperationTraits { std::is_same_v>; static constexpr bool can_overflow = (is_plus_minus || is_multiply) && - (IsDecimalV2 || IsDecimalV2 || IsDecimal128I || IsDecimal128I); + (IsDecimalV2 || IsDecimalV2 || IsDecimal256 || IsDecimal256); static constexpr bool has_variadic_argument = !std::is_void_v()))>; }; @@ -239,7 +239,7 @@ struct DecimalBinaryOperation { Op::vector_vector(a, b, c, size); } else { for (size_t i = 0; i < size; i++) { - c[i] = apply(a[i], b[i]); + c[i] = typename ArrayC::value_type(apply(a[i], b[i])); } } } @@ -251,11 +251,20 @@ struct DecimalBinaryOperation { if constexpr (IsDecimalV2 || IsDecimalV2) { /// default: use it if no return before for (size_t i = 0; i < size; ++i) { - c[i] = apply(a[i], b[i], null_map[i]); + c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i])); } } else if constexpr (OpTraits::is_division && (IsDecimalNumber || IsDecimalNumber)) { for (size_t i = 0; i < size; ++i) { - c[i] = apply_scaled_div(a[i], b[i], null_map[i]); + if constexpr (IsDecimalNumber && IsDecimalNumber) { + c[i] = typename ArrayC::value_type( + apply_scaled_div(a[i].value, b[i].value, null_map[i])); + } else if constexpr (IsDecimalNumber) { + c[i] = typename ArrayC::value_type( + apply_scaled_div(a[i].value, b[i], null_map[i])); + } else { + c[i] = typename ArrayC::value_type( + apply_scaled_div(a[i], b[i].value, null_map[i])); + } } } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && (IsDecimalNumber || IsDecimalNumber)) { @@ -264,7 +273,7 @@ struct 
DecimalBinaryOperation { } } else { for (size_t i = 0; i < size; ++i) { - c[i] = apply(a[i], b[i], null_map[i]); + c[i] = typename ArrayC::value_type(apply(a[i], b[i], null_map[i])); } } } @@ -273,14 +282,14 @@ struct DecimalBinaryOperation { typename ArrayC::value_type* c, size_t size) { if constexpr (OpTraits::is_division && IsDecimalNumber) { for (size_t i = 0; i < size; ++i) { - c[i] = apply_scaled_div(a[i], b); + c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b)); } return; } /// default: use it if no return before for (size_t i = 0; i < size; ++i) { - c[i] = apply(a[i], b); + c[i] = typename ArrayC::value_type(apply(a[i], b)); } } @@ -288,7 +297,7 @@ struct DecimalBinaryOperation { typename ArrayC::value_type* c, NullMap& null_map, size_t size) { if constexpr (OpTraits::is_division && IsDecimalNumber) { for (size_t i = 0; i < size; ++i) { - c[i] = apply_scaled_div(a[i], b, null_map[i]); + c[i] = typename ArrayC::value_type(apply_scaled_div(a[i], b.value, null_map[i])); } } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && (IsDecimalNumber || IsDecimalNumber)) { @@ -297,7 +306,7 @@ struct DecimalBinaryOperation { } } else { for (size_t i = 0; i < size; ++i) { - c[i] = apply(a[i], b, null_map[i]); + c[i] = typename ArrayC::value_type(apply(a[i], b, null_map[i])); } } } @@ -307,11 +316,12 @@ struct DecimalBinaryOperation { if constexpr (IsDecimalV2 || IsDecimalV2) { DecimalV2Value da(a); for (size_t i = 0; i < size; ++i) { - c[i] = Op::template apply(da, DecimalV2Value(b[i])).value(); + c[i] = typename ArrayC::value_type( + Op::template apply(da, DecimalV2Value(b[i])).value()); } } else { for (size_t i = 0; i < size; ++i) { - c[i] = apply(a, b[i]); + c[i] = typename ArrayC::value_type(apply(a, b[i])); } } } @@ -320,7 +330,7 @@ struct DecimalBinaryOperation { typename ArrayC::value_type* c, NullMap& null_map, size_t size) { if constexpr (OpTraits::is_division && IsDecimalNumber) { for (size_t i = 0; i < size; ++i) { - c[i] = 
apply_scaled_div(a, b[i], null_map[i]); + c[i] = typename ArrayC::value_type(apply_scaled_div(a, b[i].value, null_map[i])); } } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && (IsDecimalNumber || IsDecimalNumber)) { @@ -329,23 +339,27 @@ struct DecimalBinaryOperation { } } else { for (size_t i = 0; i < size; ++i) { - c[i] = apply(a, b[i], null_map[i]); + c[i] = typename ArrayC::value_type(apply(a, b[i], null_map[i])); } } } - static ResultType constant_constant(A a, B b) { return apply(a, b); } + static ResultType constant_constant(A a, B b) { return ResultType(apply(a, b)); } static ResultType constant_constant(A a, B b, UInt8& is_null) { if constexpr (OpTraits::is_division && IsDecimalNumber) { - return apply_scaled_div(a, b, is_null); + if constexpr (IsDecimalNumber) { + return ResultType(apply_scaled_div(a.value, b.value, is_null)); + } else { + return ResultType(apply_scaled_div(a, b.value, is_null)); + } } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && (IsDecimalNumber || IsDecimalNumber)) { NativeResultType res; is_null = apply_op_safely(a, b, res); - return res; + return ResultType(res); } else { - return apply(a, b, is_null); + return ResultType(apply(a, b, is_null)); } } @@ -459,7 +473,7 @@ struct DecimalBinaryOperation { NativeResultType res; // TODO handle overflow gracefully if (Op::template apply(a, b, res)) { - res = type_limit::max(); + res = type_limit::max().value; } return res; } else { @@ -559,6 +573,15 @@ inline constexpr bool IsIntegral = true; template constexpr bool UseLeftDecimal = false; template <> +inline constexpr bool UseLeftDecimal, DataTypeDecimal> = + true; +template <> +inline constexpr bool UseLeftDecimal, DataTypeDecimal> = + true; +template <> +inline constexpr bool UseLeftDecimal, DataTypeDecimal> = + true; +template <> inline constexpr bool UseLeftDecimal, DataTypeDecimal> = true; template <> @@ -725,8 +748,8 @@ class FunctionBinaryArithmetic : public IFunction { return 
cast_type_to_either, DataTypeDecimal, - DataTypeDecimal, DataTypeDecimal>( - type, std::forward(f)); + DataTypeDecimal, DataTypeDecimal, + DataTypeDecimal>(type, std::forward(f)); } template diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h index 5f0c7d2a3df4af9..8b59fcf8aa335e8 100644 --- a/be/src/vec/functions/function_cast.h +++ b/be/src/vec/functions/function_cast.h @@ -817,6 +817,9 @@ struct NameToDecimal128 { struct NameToDecimal128I { static constexpr auto name = "toDecimal128I"; }; +struct NameToDecimal256 { + static constexpr auto name = "toDecimal256"; +}; struct NameToUInt8 { static constexpr auto name = "toUInt8"; }; @@ -930,6 +933,12 @@ StringParser::ParseResult try_parse_decimal_impl(typename DataType::FieldType& x UInt32 precision = ((PrecisionScaleArg)additions).precision; return try_read_decimal_text(x, rb, precision, scale); } + + if constexpr (IsDataTypeDecimal256) { + UInt32 scale = ((PrecisionScaleArg)additions).scale; + UInt32 precision = ((PrecisionScaleArg)additions).precision; + return try_read_decimal_text(x, rb, precision, scale); + } } /// Monotonicity. 
@@ -1096,7 +1105,8 @@ class FunctionConvert : public IFunction { static constexpr auto name = Name::name; static constexpr bool to_decimal = std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; + std::is_same_v || std::is_same_v || + std::is_same_v; static FunctionPtr create() { return std::make_shared(); } @@ -1203,6 +1213,8 @@ using FunctionToDecimal128 = FunctionConvert, NameToDecimal128, UnknownMonotonicity>; using FunctionToDecimal128I = FunctionConvert, NameToDecimal128I, UnknownMonotonicity>; +using FunctionToDecimal256 = + FunctionConvert, NameToDecimal256, UnknownMonotonicity>; using FunctionToDate = FunctionConvert; using FunctionToDateTime = FunctionConvert; using FunctionToDateV2 = FunctionConvert; @@ -1273,6 +1285,10 @@ struct FunctionTo> { using Type = FunctionToDecimal128I; }; template <> +struct FunctionTo> { + using Type = FunctionToDecimal256; +}; +template <> struct FunctionTo { using Type = FunctionToDate; }; @@ -1430,6 +1446,9 @@ struct ConvertImpl, Name> template struct ConvertImpl, Name> : ConvertThroughParsing, Name> {}; +template +struct ConvertImpl, Name> + : ConvertThroughParsing, Name> {}; template class FunctionConvertFromString : public IFunction { @@ -2093,7 +2112,8 @@ class FunctionCast final : public IFunctionBase { if constexpr (std::is_same_v> || std::is_same_v> || std::is_same_v> || - std::is_same_v>) { + std::is_same_v> || + std::is_same_v>) { ret = create_decimal_wrapper(from_type, check_and_get_data_type(to_type.get())); return true; diff --git a/be/src/vec/functions/function_multi_same_args.h b/be/src/vec/functions/function_multi_same_args.h index aaac717331ce3b2..0c45c7cd446ef7a 100644 --- a/be/src/vec/functions/function_multi_same_args.h +++ b/be/src/vec/functions/function_multi_same_args.h @@ -18,7 +18,6 @@ #pragma once #include "udf/udf.h" -#include "vec/data_types/get_least_supertype.h" #include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" #include 
"vec/utils/template_helpers.hpp" diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 69c2e3a528678a2..aa76dafd20a8785 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -2423,12 +2423,33 @@ struct MoneyFormatDecimalImpl { frac_part = frac_part * multiplier; } - StringRef str = MoneyFormat::do_money_format( + StringRef str = MoneyFormat::do_money_format<__int128, 53>( context, decimal128_column->get_whole_part(i), frac_part); result_column->insert_data(str.data, str.size); } } + // TODO: decimal256 + /* else if (auto* decimal256_column = + check_and_get_column>(*col_ptr)) { + const UInt32 scale = decimal256_column->get_scale(); + const auto multiplier = + scale > 2 ? common::exp10_i32(scale - 2) : common::exp10_i32(2 - scale); + for (size_t i = 0; i < input_rows_count; i++) { + Decimal256 frac_part = decimal256_column->get_fractional_part(i); + if (scale > 2) { + int delta = ((frac_part % multiplier) << 1) > multiplier; + frac_part = Decimal256(frac_part / multiplier + delta); + } else if (scale < 2) { + frac_part = Decimal256(frac_part * multiplier); + } + + StringRef str = MoneyFormat::do_money_format( + context, decimal256_column->get_whole_part(i), frac_part); + + result_column->insert_data(str.data, str.size); + } + }*/ } }; diff --git a/be/src/vec/functions/function_unary_arithmetic.h b/be/src/vec/functions/function_unary_arithmetic.h index 63376f93d7395eb..51aabb9a803aeb8 100644 --- a/be/src/vec/functions/function_unary_arithmetic.h +++ b/be/src/vec/functions/function_unary_arithmetic.h @@ -72,8 +72,8 @@ class FunctionUnaryArithmetic : public IFunction { DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, DataTypeInt128, DataTypeFloat32, DataTypeFloat64, DataTypeDecimal, DataTypeDecimal, - DataTypeDecimal, DataTypeDecimal>( - type, std::forward(f)); + DataTypeDecimal, DataTypeDecimal, + DataTypeDecimal>(type, std::forward(f)); } public: diff --git 
a/be/src/vec/functions/function_width_bucket.cpp b/be/src/vec/functions/function_width_bucket.cpp index 1daf3ed5eaa033a..40c08a950e25881 100644 --- a/be/src/vec/functions/function_width_bucket.cpp +++ b/be/src/vec/functions/function_width_bucket.cpp @@ -148,6 +148,9 @@ class FunctionWidthBucket : public IFunction { } else if (which.is_decimal128i()) { _execute(expr_column, min_value_column, max_value_column, num_buckets, nested_column_column); + } else if (which.is_decimal256()) { + _execute(expr_column, min_value_column, max_value_column, num_buckets, + nested_column_column); } else if (which.is_date()) { _execute(expr_column, min_value_column, max_value_column, num_buckets, nested_column_column); diff --git a/be/src/vec/functions/functions_comparison.h b/be/src/vec/functions/functions_comparison.h index 0bf03310b9f85e1..3b58f21c406d9fe 100644 --- a/be/src/vec/functions/functions_comparison.h +++ b/be/src/vec/functions/functions_comparison.h @@ -35,7 +35,6 @@ #include "vec/core/decimal_comparison.h" #include "vec/data_types/data_type_number.h" #include "vec/data_types/data_type_string.h" -#include "vec/data_types/get_least_supertype.h" #include "vec/functions/function.h" #include "vec/functions/function_helpers.h" #include "vec/functions/functions_logical.h" diff --git a/be/src/vec/functions/if.cpp b/be/src/vec/functions/if.cpp index 1664c0719ecf314..9b14abce2ae0a17 100644 --- a/be/src/vec/functions/if.cpp +++ b/be/src/vec/functions/if.cpp @@ -46,7 +46,6 @@ #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" -#include "vec/data_types/get_least_supertype.h" #include "vec/functions/function.h" #include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" diff --git a/be/src/vec/functions/least_greast.cpp b/be/src/vec/functions/least_greast.cpp index 90f8fa99cf7dc86..be35504d83870e3 100644 --- a/be/src/vec/functions/least_greast.cpp +++ 
b/be/src/vec/functions/least_greast.cpp @@ -138,7 +138,8 @@ struct CompareMultiImpl { } } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { for (size_t i = 0; i < input_rows_count; ++i) { using type = std::decay_t; result_raw_data[i] = @@ -243,7 +244,8 @@ struct FunctionFieldImpl { } } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { for (size_t i = 0; i < input_rows_count; ++i) { using type = std::decay_t; res_data[i] |= (!res_data[i] * diff --git a/be/src/vec/olap/olap_data_convertor.cpp b/be/src/vec/olap/olap_data_convertor.cpp index 9f3fe2b7ac48282..0870f89d671139a 100644 --- a/be/src/vec/olap/olap_data_convertor.cpp +++ b/be/src/vec/olap/olap_data_convertor.cpp @@ -131,6 +131,9 @@ OlapBlockDataConvertor::create_olap_column_data_convertor(const TabletColumn& co case FieldType::OLAP_FIELD_TYPE_DECIMAL128I: { return std::make_unique>(); } + case FieldType::OLAP_FIELD_TYPE_DECIMAL256: { + return std::make_unique>(); + } case FieldType::OLAP_FIELD_TYPE_JSONB: { return std::make_unique(true); } diff --git a/be/src/vec/sink/vtablet_block_convertor.cpp b/be/src/vec/sink/vtablet_block_convertor.cpp index 436eb3639de8521..fb5ed14873acc23 100644 --- a/be/src/vec/sink/vtablet_block_convertor.cpp +++ b/be/src/vec/sink/vtablet_block_convertor.cpp @@ -137,14 +137,16 @@ DecimalType OlapTableBlockConvertor::_get_decimalv3_min_or_max(const TypeDescrip pmap = IsMin ? &_min_decimal32_val : &_max_decimal32_val; } else if constexpr (std::is_same_v) { pmap = IsMin ? &_min_decimal64_val : &_max_decimal64_val; - } else { + } else if constexpr (std::is_same_v) { pmap = IsMin ? &_min_decimal128_val : &_max_decimal128_val; + } else { + pmap = IsMin ? 
&_min_decimal256_val : &_max_decimal256_val; } // found auto iter = pmap->find(type.precision); if (iter != pmap->end()) { - return iter->second; + return DecimalType(iter->second); } typename DecimalType::NativeType value; @@ -154,7 +156,7 @@ DecimalType OlapTableBlockConvertor::_get_decimalv3_min_or_max(const TypeDescrip value = vectorized::max_decimal_value(type.precision); } pmap->emplace(type.precision, value); - return value; + return DecimalType(value); } Status OlapTableBlockConvertor::_validate_column(RuntimeState* state, const TypeDescriptor& type, @@ -336,6 +338,10 @@ Status OlapTableBlockConvertor::_validate_column(RuntimeState* state, const Type CHECK_VALIDATION_FOR_DECIMALV3(vectorized::Decimal128I); break; } + case TYPE_DECIMAL256: { + CHECK_VALIDATION_FOR_DECIMALV3(vectorized::Decimal256); + break; + } #undef CHECK_VALIDATION_FOR_DECIMALV3 case TYPE_ARRAY: { const auto* column_array = diff --git a/be/src/vec/sink/vtablet_block_convertor.h b/be/src/vec/sink/vtablet_block_convertor.h index 27440c628be4bba..3ee3d582653d71a 100644 --- a/be/src/vec/sink/vtablet_block_convertor.h +++ b/be/src/vec/sink/vtablet_block_convertor.h @@ -93,6 +93,8 @@ class OlapTableBlockConvertor { std::map _min_decimal64_val; std::map _max_decimal128_val; std::map _min_decimal128_val; + std::map _max_decimal256_val; + std::map _min_decimal256_val; std::vector _filter_map; diff --git a/be/src/vec/sink/writer/vmysql_table_writer.cpp b/be/src/vec/sink/writer/vmysql_table_writer.cpp index 77e054e05c9c4a6..2de4dd5ac30f369 100644 --- a/be/src/vec/sink/writer/vmysql_table_writer.cpp +++ b/be/src/vec/sink/writer/vmysql_table_writer.cpp @@ -205,7 +205,8 @@ Status VMysqlTableWriter::_insert_row(vectorized::Block& block, size_t row) { } case TYPE_DECIMAL32: case TYPE_DECIMAL64: - case TYPE_DECIMAL128I: { + case TYPE_DECIMAL128I: + case TYPE_DECIMAL256: { auto val = type_ptr->to_string(*column, row); fmt::format_to(_insert_stmt_buffer, "{}", val); break; diff --git 
a/be/test/vec/data_types/decimal_test.cpp b/be/test/vec/data_types/decimal_test.cpp new file mode 100644 index 000000000000000..74a65dd2b246ad2 --- /dev/null +++ b/be/test/vec/data_types/decimal_test.cpp @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include + +#include + +#include "gtest/gtest_pred_impl.h" +#include "runtime/type_limit.h" +#include "vec/core/types.h" +namespace doris::vectorized { + +TEST(DecimalTest, Decimal256) { + // 9999999999999999999999999999999999999999999999999999999999999999999999999999 + Decimal256 dec1(type_limit::max()); + auto des_str = dec1.to_string(10); + EXPECT_EQ(des_str, + "999999999999999999999999999999999999999999999999999999999999999999.9999999999"); + des_str = dec1.to_string(0); + EXPECT_EQ(des_str, + "9999999999999999999999999999999999999999999999999999999999999999999999999999"); + des_str = dec1.to_string(76); + EXPECT_EQ(des_str, + "0.9999999999999999999999999999999999999999999999999999999999999999999999999999"); + + auto dec2 = type_limit::min(); + des_str = dec2.to_string(10); + EXPECT_EQ(des_str, + "-999999999999999999999999999999999999999999999999999999999999999999.9999999999"); + des_str = dec2.to_string(0); + EXPECT_EQ(des_str, + "-9999999999999999999999999999999999999999999999999999999999999999999999999999"); + des_str = dec2.to_string(76); + EXPECT_EQ(des_str, + "-0.9999999999999999999999999999999999999999999999999999999999999999999999999999"); + + // plus + Decimal256 dec3 = dec1 + dec2; + des_str = dec3.to_string(10); + EXPECT_EQ(des_str, "0.0000000000"); + des_str = dec3.to_string(0); + EXPECT_EQ(des_str, "0"); + des_str = dec3.to_string(76); + EXPECT_EQ(des_str, + "0.0000000000000000000000000000000000000000000000000000000000000000000000000000"); + + // minus + dec2 = type_limit::max(); + dec3 = dec1 - dec2; + des_str = dec3.to_string(10); + EXPECT_EQ(des_str, "0.0000000000"); + + // multiply + + // divide + dec1 = type_limit::max(); + dec2 = vectorized::Decimal256(10); + dec3 = dec1 / dec2; + des_str = dec3.to_string(1); + EXPECT_EQ(des_str, + "99999999999999999999999999999999999999999999999999999999999999999999999999.9"); + + // overflow +} +} // namespace doris::vectorized \ No newline at end of file diff --git 
a/fe/fe-common/src/main/java/org/apache/doris/catalog/PrimitiveType.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/PrimitiveType.java index 78a60239f2ad7f2..e5a5726e52e8c70 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/PrimitiveType.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/PrimitiveType.java @@ -54,6 +54,7 @@ public enum PrimitiveType { DECIMAL32("DECIMAL32", 4, TPrimitiveType.DECIMAL32, true), DECIMAL64("DECIMAL64", 8, TPrimitiveType.DECIMAL64, true), DECIMAL128("DECIMAL128", 16, TPrimitiveType.DECIMAL128I, true), + DECIMAL256("DECIMAL256", 32, TPrimitiveType.DECIMAL256, true), TIME("TIME", 8, TPrimitiveType.TIME, false), // these following types are stored as object binary in BE. HLL("HLL", 16, TPrimitiveType.HLL, true), @@ -94,6 +95,7 @@ public enum PrimitiveType { builder.add(DECIMAL32); builder.add(DECIMAL64); builder.add(DECIMAL128); + builder.add(DECIMAL256); builder.add(DATETIMEV2); typeWithPrecision = builder.build(); } @@ -575,6 +577,7 @@ public static ImmutableSetMultimap getImplicitCast numericTypes.add(DECIMAL32); numericTypes.add(DECIMAL64); numericTypes.add(DECIMAL128); + numericTypes.add(DECIMAL256); supportedTypes = Lists.newArrayList(); supportedTypes.add(NULL_TYPE); @@ -602,6 +605,7 @@ public static ImmutableSetMultimap getImplicitCast supportedTypes.add(DECIMAL32); supportedTypes.add(DECIMAL64); supportedTypes.add(DECIMAL128); + supportedTypes.add(DECIMAL256); supportedTypes.add(BITMAP); supportedTypes.add(ARRAY); supportedTypes.add(MAP); @@ -685,6 +689,8 @@ public static PrimitiveType fromThrift(TPrimitiveType tPrimitiveType) { return DECIMAL64; case DECIMAL128I: return DECIMAL128; + case DECIMAL256: + return DECIMAL256; case TIME: return TIME; case TIMEV2: @@ -767,7 +773,7 @@ public boolean isDecimalV2Type() { } public boolean isDecimalV3Type() { - return this == DECIMAL32 || this == DECIMAL64 || this == DECIMAL128; + return this == DECIMAL32 || this == DECIMAL64 || this == DECIMAL128 || 
this == DECIMAL256; } public boolean isNumericType() { @@ -876,6 +882,7 @@ public MysqlColType toMysqlType() { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: return MysqlColType.MYSQL_TYPE_NEWDECIMAL; case STRING: return MysqlColType.MYSQL_TYPE_BLOB; @@ -913,6 +920,8 @@ public int getOlapColumnIndexSize() { return 8; case DECIMAL128: return 16; + case DECIMAL256: + return 32; default: return this.getSlotSize(); } diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java index 540f8821f5ce4de..618aecb996966be 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java @@ -78,6 +78,7 @@ public class ScalarType extends Type { public static final int MAX_DECIMAL32_PRECISION = 9; public static final int MAX_DECIMAL64_PRECISION = 18; public static final int MAX_DECIMAL128_PRECISION = 38; + public static final int MAX_DECIMAL256_PRECISION = 76; public static final int DEFAULT_MIN_AVG_DECIMAL128_SCALE = 4; public static final int MAX_DATETIMEV2_SCALE = 6; @@ -923,53 +924,6 @@ public boolean equals(Object o) { return true; } - public Type getMaxResolutionType() { - if (isIntegerType()) { - return ScalarType.BIGINT; - // Timestamps get summed as DOUBLE for AVG. 
- } else if (isFloatingPointType()) { - return ScalarType.DOUBLE; - } else if (isNull()) { - return ScalarType.NULL; - } else if (isDecimalV2()) { - return createDecimalTypeInternal(MAX_PRECISION, scale, true); - } else if (getPrimitiveType() == PrimitiveType.DECIMAL32) { - return createDecimalTypeInternal(MAX_DECIMAL32_PRECISION, scale, false); - } else if (getPrimitiveType() == PrimitiveType.DECIMAL64) { - return createDecimalTypeInternal(MAX_DECIMAL64_PRECISION, scale, false); - } else if (getPrimitiveType() == PrimitiveType.DECIMAL128) { - return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale, false); - } else if (isLargeIntType()) { - return ScalarType.LARGEINT; - } else if (isDatetimeV2()) { - return createDatetimeV2Type(6); - } else if (isTimeV2()) { - return createTimeV2Type(6); - } else { - return ScalarType.INVALID; - } - } - - public ScalarType getNextResolutionType() { - Preconditions.checkState(isNumericType() || isNull()); - if (type == PrimitiveType.DOUBLE || type == PrimitiveType.BIGINT || isNull()) { - return this; - } else if (type == PrimitiveType.DECIMALV2) { - return createDecimalTypeInternal(MAX_PRECISION, scale, true); - } else if (type == PrimitiveType.DECIMAL32) { - return createDecimalTypeInternal(MAX_DECIMAL64_PRECISION, scale, false); - } else if (type == PrimitiveType.DECIMAL64) { - return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale, false); - } else if (type == PrimitiveType.DECIMAL128) { - return createDecimalTypeInternal(MAX_DECIMAL128_PRECISION, scale, false); - } else if (type == PrimitiveType.DATETIMEV2) { - return createDatetimeV2Type(6); - } else if (type == PrimitiveType.TIMEV2) { - return createTimeV2Type(6); - } - return createType(PrimitiveType.values()[type.ordinal() + 1]); - } - /** * Returns the smallest decimal type that can safely store this type. Returns * INVALID if this type cannot be stored as a decimal. 
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java index ad498773d472062..d0e44bf27c7eb65 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java @@ -87,6 +87,10 @@ public abstract class Type { public static final ScalarType DEFAULT_DECIMAL128 = ScalarType.createDecimalType(PrimitiveType.DECIMAL128, ScalarType.MAX_DECIMAL128_PRECISION, ScalarType.DEFAULT_SCALE); + + public static final ScalarType DEFAULT_DECIMAL256 = + ScalarType.createDecimalType(PrimitiveType.DECIMAL256, ScalarType.MAX_DECIMAL256_PRECISION, + ScalarType.DEFAULT_SCALE); public static final ScalarType DEFAULT_DECIMALV3 = DEFAULT_DECIMAL32; public static final ScalarType DEFAULT_DATETIMEV2 = ScalarType.createDatetimeV2Type(0); public static final ScalarType DATETIMEV2 = DEFAULT_DATETIMEV2; @@ -96,6 +100,7 @@ public abstract class Type { public static final ScalarType DECIMAL32 = DEFAULT_DECIMAL32; public static final ScalarType DECIMAL64 = DEFAULT_DECIMAL64; public static final ScalarType DECIMAL128 = DEFAULT_DECIMAL128; + public static final ScalarType DECIMAL256 = DEFAULT_DECIMAL256; public static final ScalarType JSONB = new ScalarType(PrimitiveType.JSONB); // (ScalarType) ScalarType.createDecimalTypeInternal(-1, -1); public static final ScalarType DEFAULT_VARCHAR = ScalarType.createVarcharType(-1); @@ -391,7 +396,7 @@ public boolean isDecimalV3OrContainsDecimalV3() { public boolean isDecimalV3() { return isScalarType(PrimitiveType.DECIMAL32) || isScalarType(PrimitiveType.DECIMAL64) - || isScalarType(PrimitiveType.DECIMAL128); + || isScalarType(PrimitiveType.DECIMAL128) || isScalarType(PrimitiveType.DECIMAL256); } public boolean isDatetimeV2() { @@ -975,7 +980,8 @@ protected static Pair fromThrift(TTypeDesc col, int nodeIdx) { scalarType.getScale()); } else if (scalarType.getType() == TPrimitiveType.DECIMAL32 || 
scalarType.getType() == TPrimitiveType.DECIMAL64 - || scalarType.getType() == TPrimitiveType.DECIMAL128I) { + || scalarType.getType() == TPrimitiveType.DECIMAL128I + || scalarType.getType() == TPrimitiveType.DECIMAL256) { Preconditions.checkState(scalarType.isSetPrecision() && scalarType.isSetScale()); type = ScalarType.createDecimalV3Type(scalarType.getPrecision(), @@ -1130,6 +1136,7 @@ public Integer getPrecision() { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: case DATETIMEV2: case TIMEV2: return t.decimalPrecision(); @@ -1166,6 +1173,7 @@ public Integer getDecimalDigits() { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: return t.decimalScale(); default: return null; @@ -1200,6 +1208,7 @@ public Integer getNumPrecRadix() { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: return 10; default: // everything else (including boolean and string) is null @@ -1267,6 +1276,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; // TINYINT @@ -1288,6 +1298,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[TINYINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; compatibilityMatrix[TINYINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[TINYINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[TINYINT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[TINYINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][TIME.ordinal()] = 
PrimitiveType.DOUBLE; compatibilityMatrix[TINYINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1315,6 +1326,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[SMALLINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; compatibilityMatrix[SMALLINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[SMALLINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[SMALLINT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[SMALLINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[SMALLINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1345,6 +1357,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[INT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; compatibilityMatrix[INT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[INT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[INT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[INT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[INT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[INT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1376,6 +1389,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[BIGINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[BIGINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[BIGINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[BIGINT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[BIGINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; 
compatibilityMatrix[BIGINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1399,6 +1413,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[LARGEINT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[LARGEINT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[LARGEINT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[LARGEINT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[LARGEINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[LARGEINT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1421,6 +1436,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[FLOAT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[FLOAT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[FLOAT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[FLOAT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[FLOAT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; compatibilityMatrix[FLOAT.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1439,6 +1455,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DOUBLE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DOUBLE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; compatibilityMatrix[DOUBLE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DOUBLE.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DOUBLE.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][TIME.ordinal()] = PrimitiveType.DOUBLE; 
compatibilityMatrix[DOUBLE.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1461,6 +1478,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DATE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; compatibilityMatrix[DATE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DATE.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DATE.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1481,6 +1499,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DATEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL32; compatibilityMatrix[DATEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DATEV2.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DATEV2.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATEV2.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1500,6 +1519,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DATETIME.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATETIME.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATETIME.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DATETIME.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DATETIME.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; 
compatibilityMatrix[DATETIME.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1519,6 +1539,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DATETIMEV2.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DATETIMEV2.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DATETIMEV2.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1538,6 +1559,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[CHAR.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1553,6 +1575,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[VARCHAR.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIME.ordinal()] = 
PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1576,6 +1599,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[STRING.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[STRING.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[STRING.ordinal()][JSONB.ordinal()] = PrimitiveType.STRING; compatibilityMatrix[STRING.ordinal()][VARIANT.ordinal()] = PrimitiveType.STRING; compatibilityMatrix[STRING.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1585,6 +1609,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[JSONB.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[JSONB.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[JSONB.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[JSONB.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[JSONB.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[JSONB.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[JSONB.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1601,6 +1626,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[VARIANT.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARIANT.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARIANT.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[VARIANT.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARIANT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; 
compatibilityMatrix[VARIANT.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[VARIANT.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1627,6 +1653,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMALV2.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; // DECIMAL32 @@ -1642,6 +1669,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DECIMAL32.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL32.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DECIMAL32.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; // DECIMAL64 @@ -1657,6 +1685,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DECIMAL64.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL64; compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL64.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DECIMAL64.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; // DECIMAL128 @@ -1672,8 +1701,24 @@ public Integer getNumPrecRadix() { compatibilityMatrix[DECIMAL128.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL32.ordinal()] = 
PrimitiveType.DECIMAL128; compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL128; + compatibilityMatrix[DECIMAL128.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.DECIMAL256; compatibilityMatrix[DECIMAL128.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + // DECIMAL256 + compatibilityMatrix[DECIMAL256.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][DATEV2.ordinal()] = PrimitiveType.DECIMAL256; + compatibilityMatrix[DECIMAL256.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.DECIMAL256; + compatibilityMatrix[DECIMAL256.ordinal()][BITMAP.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][STRING.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][QUANTILE_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL256.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.DECIMAL256; + compatibilityMatrix[DECIMAL256.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.DECIMAL256; + compatibilityMatrix[DECIMAL256.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.DECIMAL256; + compatibilityMatrix[DECIMAL256.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; + // HLL compatibilityMatrix[HLL.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][TIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1687,6 +1732,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[HLL.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][DECIMAL128.ordinal()] = 
PrimitiveType.INVALID_TYPE; + compatibilityMatrix[HLL.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[HLL.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1702,6 +1748,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[BITMAP.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BITMAP.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BITMAP.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[BITMAP.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[BITMAP.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; //QUANTILE_STATE @@ -1713,6 +1760,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[QUANTILE_STATE.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[QUANTILE_STATE.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; //AGG_STATE @@ -1724,6 +1772,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[AGG_STATE.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[AGG_STATE.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[AGG_STATE.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[AGG_STATE.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; // TIME why here not??? 
compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1731,6 +1780,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[TIME.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIME.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][DATEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][DATETIMEV2.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIME.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; @@ -1740,6 +1790,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[TIMEV2.ordinal()][DECIMAL32.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIMEV2.ordinal()][DECIMAL64.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIMEV2.ordinal()][DECIMAL128.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[TIMEV2.ordinal()][DECIMAL256.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[TIMEV2.ordinal()][AGG_STATE.ordinal()] = PrimitiveType.INVALID_TYPE; // Check all of the necessary entries that should be filled. 
@@ -1801,6 +1852,8 @@ public Type getResultType() { return DECIMAL64; case DECIMAL128: return DECIMAL128; + case DECIMAL256: + return DECIMAL256; case STRING: return STRING; case JSONB: @@ -1948,11 +2001,6 @@ private static Type getDateComparisonResultType(ScalarType t1, ScalarType t2) { } } - public Type getMaxResolutionType() { - Preconditions.checkState(true, "must implemented"); - return null; - } - public Type getNumResultType() { switch (getPrimitiveType()) { case BOOLEAN: @@ -1984,6 +2032,8 @@ public Type getNumResultType() { return Type.DECIMAL64; case DECIMAL128: return Type.DECIMAL128; + case DECIMAL256: + return Type.DECIMAL256; default: return Type.INVALID; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index 747f948c37ca787..bf5c75c58f02ea7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -514,6 +514,7 @@ public CastExpr rewriteExpr(List parameters, List inputParamsExprs case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: // normal decimal if (targetType.getPrecision() != 0) { newTargetType = targetType; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java index 349be4e45ef6145..0f5ff1e1313969c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ColumnDef.java @@ -500,6 +500,7 @@ public static void validateDefaultValue(Type type, String defaultValue, DefaultV case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: DecimalLiteral decimalLiteral = new DecimalLiteral(defaultValue); decimalLiteral.checkPrecisionAndScale(scalarType.getScalarPrecision(), scalarType.getScalarScale()); break; diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index d8749ad6cca7703..fe03afd02c3a2eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -2528,6 +2528,8 @@ protected Type getActualScalarType(Type originType) { return Type.DECIMAL64; } else if (originType.getPrimitiveType() == PrimitiveType.DECIMAL128) { return Type.DECIMAL128; + } else if (originType.getPrimitiveType() == PrimitiveType.DECIMAL256) { + return Type.DECIMAL256; } else if (originType.getPrimitiveType() == PrimitiveType.DATETIMEV2) { return Type.DATETIMEV2; } else if (originType.getPrimitiveType() == PrimitiveType.DATEV2) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java index babcc564c0844ed..bf24955970c7b20 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LiteralExpr.java @@ -75,6 +75,7 @@ public static LiteralExpr create(String value, Type type) throws AnalysisExcepti case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: literalExpr = new DecimalLiteral(value); break; case CHAR: diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java index 66747e0002f7fdd..3119781c97cdb82 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java @@ -245,6 +245,7 @@ protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: try { DecimalLiteral res = new DecimalLiteral(new BigDecimal(value).stripTrailingZeros()); res.setType(targetType); diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/TypeDef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/TypeDef.java index 9c80ee01ab1f451..845b304db16e606 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/TypeDef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/TypeDef.java @@ -29,6 +29,8 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.thrift.TColumnDesc; import org.apache.doris.thrift.TPrimitiveType; @@ -292,6 +294,33 @@ private void analyzeScalarType(ScalarType scalarType) } break; } + case DECIMAL256: { + SessionVariable sessionVariable = ConnectContext.get().getSessionVariable(); + boolean enableDecimal256 = sessionVariable.enableDecimal256(); + boolean enableNereidsPlanner = sessionVariable.isEnableNereidsPlanner(); + if (enableNereidsPlanner && enableDecimal256) { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + if (precision < 1 || precision > ScalarType.MAX_DECIMAL256_PRECISION) { + throw new AnalysisException("Precision of decimal256 must between 1 and 76." + + " Precision was set to: " + precision + "."); + } + // scale >= 0 + if (scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + " Scale was set to: " + + scale + "."); + } + // scale <= precision + if (scale > precision) { + throw new AnalysisException("Scale of decimal must be smaller than precision."
+ + " Scale is " + scale + " and precision is " + precision); + } + break; + } else { + int precision = scalarType.decimalPrecision(); + throw new AnalysisException("Column of type Decimal256 with precision " + precision + " in not supported."); + } + } case TIMEV2: case DATETIMEV2: { int precision = scalarType.decimalPrecision(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java index c2f6d466f220cf7..882689dbfa8a12d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java @@ -173,6 +173,7 @@ public void analyze() throws AnalysisException { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: case DECIMALV2: if (!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) { typeDefParams.add(scalarType.getScalarPrecisionStr()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java index 6de1b5a9d410d5a..f8d0d3e716f0c6b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java @@ -898,6 +898,7 @@ public String getSignatureString(Map typeStringMap) { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: sb.append(String.format(typeStringMap.get(dataType), getPrecision(), getScale())); break; case ARRAY: diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java b/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java index 7fd02926b7d79e4..3b0676118f9616e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/util/Util.java @@ -87,6 +87,7 @@ public class Util { TYPE_STRING_MAP.put(PrimitiveType.DECIMAL32, "decimal(%d, %d)"); TYPE_STRING_MAP.put(PrimitiveType.DECIMAL64, 
"decimal(%d, %d)"); TYPE_STRING_MAP.put(PrimitiveType.DECIMAL128, "decimal(%d, %d)"); + TYPE_STRING_MAP.put(PrimitiveType.DECIMAL256, "decimal(%d, %d)"); TYPE_STRING_MAP.put(PrimitiveType.HLL, "varchar(%d)"); TYPE_STRING_MAP.put(PrimitiveType.BOOLEAN, "bool"); TYPE_STRING_MAP.put(PrimitiveType.BITMAP, "bitmap"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSerializer.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSerializer.java index 59375ec4b67ed1d..228f3891ce90e58 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSerializer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSerializer.java @@ -265,7 +265,8 @@ private int getMysqlTypeLength(Type type) { case DECIMALV2: case DECIMAL32: case DECIMAL64: - case DECIMAL128: { + case DECIMAL128: + case DECIMAL256: { // https://github.com/mysql/mysql-connector-j/blob/release/5.1/src/com/mysql/jdbc/ResultSetMetaData.java // in function: int getPrecision(int column) // f.getDecimals() > 0 ? clampedGetLength(f) - 1 + f.getPrecisionAdjustFactor() @@ -296,6 +297,7 @@ public int getMysqlDecimals(Type type) { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: return ((ScalarType) type).decimalScale(); case FLOAT: case DOUBLE: diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/exceptions/NotSupportedException.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/exceptions/NotSupportedException.java new file mode 100644 index 000000000000000..bb707b6562088a5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/exceptions/NotSupportedException.java @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.exceptions; + +/** + * Exception thrown when a requested feature or data type is not supported or not enabled. + */ +public class NotSupportedException extends RuntimeException { + public NotSupportedException(String msg) { + super(String.format("Not Supported: %s", msg)); + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java index 6ed045a3004f294..33fe8e85dd562e1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java @@ -209,7 +209,8 @@ private Map evalOnBE(Map> paramMa type = DateTimeV2Type.of(pScalarType.getScale()); } else if (primitiveType == PrimitiveType.DECIMAL32 || primitiveType == PrimitiveType.DECIMAL64 - || primitiveType == PrimitiveType.DECIMAL128) { + || primitiveType == PrimitiveType.DECIMAL128 + || primitiveType == PrimitiveType.DECIMAL256) { type = DecimalV3Type.createDecimalV3Type( pScalarType.getPrecision(), pScalarType.getScale()); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java index
eaa3c1ad2a96da1..002849bb8166cab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java @@ -68,9 +68,9 @@ public DataType getDataType() throws UnboundException { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { int retPercision = t1.getPrecision() + t2.getScale() + Config.div_precision_increment; - Preconditions.checkState(retPercision <= DecimalV3Type.MAX_DECIMAL128_PRECISION, + Preconditions.checkState(retPercision <= DecimalV3Type.MAX_DECIMAL256_PRECISION, "target precision " + retPercision + " larger than precision " - + DecimalV3Type.MAX_DECIMAL128_PRECISION + " in Divide return type"); + + DecimalV3Type.MAX_DECIMAL256_PRECISION + " in Divide return type"); int retScale = t1.getScale() + t2.getScale() + Config.div_precision_increment; int targetPercision = retPercision; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java index a42c3fa5c288e1c..87bd92850d4063f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -52,7 +53,12 @@ public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) int retPercision = t1.getPrecision() + t2.getPrecision(); int retScale = t1.getScale() + t2.getScale(); if (retPercision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { - retPercision = 
DecimalV3Type.MAX_DECIMAL128_PRECISION; + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); + if (enableDecimal256) { + retPercision = DecimalV3Type.MAX_DECIMAL256_PRECISION; + } else { + retPercision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + } } Preconditions.checkState(retPercision >= retScale, "scale " + retScale + " larger than precision " + retPercision diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForSum.java index 1409a1d559cf363..d7665f3f262c27e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForSum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForSum.java @@ -21,6 +21,8 @@ import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.qe.ConnectContext; + /** ComputePrecisionForSum */ public interface ComputePrecisionForSum extends ComputePrecision { @Override @@ -28,9 +30,12 @@ default FunctionSignature computePrecision(FunctionSignature signature) { DataType argumentType = getArgumentType(0); if (signature.getArgType(0) instanceof DecimalV3Type) { DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType); + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); return signature.withArgumentType(0, decimalV3Type) .withReturnType(DecimalV3Type.createDecimalV3Type( - DecimalV3Type.MAX_DECIMAL128_PRECISION, decimalV3Type.getScale())); + enableDecimal256 ? 
DecimalV3Type.MAX_DECIMAL256_PRECISION + : DecimalV3Type.MAX_DECIMAL128_PRECISION, + decimalV3Type.getScale())); } else { return signature; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java index 6eea3d9e41fa089..512421032b6d9d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java @@ -36,6 +36,8 @@ import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.qe.ConnectContext; + import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -88,6 +90,7 @@ public void checkLegalityBeforeTypeCoercion() { public FunctionSignature computePrecision(FunctionSignature signature) { DataType argumentType = getArgumentType(0); if (signature.getArgType(0) instanceof DecimalV3Type) { + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType); // DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast int precision = decimalV3Type.getPrecision(); @@ -95,14 +98,22 @@ public FunctionSignature computePrecision(FunctionSignature signature) { if (decimalV3Type.getScale() < ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) { scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE; precision = precision - decimalV3Type.getScale() + scale; - if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { - precision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + if (enableDecimal256) { + if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) { + precision = DecimalV3Type.MAX_DECIMAL256_PRECISION; + } + } else { + if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { + precision = 
DecimalV3Type.MAX_DECIMAL128_PRECISION; + } } } decimalV3Type = DecimalV3Type.createDecimalV3Type(precision, scale); return signature.withArgumentType(0, decimalV3Type) .withReturnType(DecimalV3Type.createDecimalV3Type( - DecimalV3Type.MAX_DECIMAL128_PRECISION, decimalV3Type.getScale() + enableDecimal256 ? DecimalV3Type.MAX_DECIMAL256_PRECISION + : DecimalV3Type.MAX_DECIMAL128_PRECISION, + decimalV3Type.getScale() )); } else { return signature; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java index 7501d3b8db59b7c..c30b6c9024ff39d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java @@ -20,7 +20,9 @@ import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.exceptions.NotSupportedException; import org.apache.doris.nereids.types.coercion.FractionalType; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; @@ -29,14 +31,20 @@ import java.util.Map; import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * Decimal type in Nereids. 
*/ @Developing public class DecimalV3Type extends FractionalType { + private static final Logger LOG = LoggerFactory.getLogger(DecimalV3Type.class); + public static final int MAX_DECIMAL32_PRECISION = 9; public static final int MAX_DECIMAL64_PRECISION = 18; public static final int MAX_DECIMAL128_PRECISION = 38; + public static final int MAX_DECIMAL256_PRECISION = 76; public static final DecimalV3Type WILDCARD = new DecimalV3Type(-1, -1); public static final DecimalV3Type SYSTEM_DEFAULT = new DecimalV3Type(MAX_DECIMAL128_PRECISION, DEFAULT_SCALE); @@ -99,12 +107,17 @@ public static DecimalV3Type createDecimalV3Type(int precision) { /** createDecimalV3Type. */ public static DecimalV3Type createDecimalV3Type(int precision, int scale) { - Preconditions.checkArgument(precision > 0 && precision <= MAX_DECIMAL128_PRECISION, - "precision should in (0, " + MAX_DECIMAL128_PRECISION + "], but real precision is " + precision); + Preconditions.checkArgument(precision > 0 && precision <= MAX_DECIMAL256_PRECISION, + "precision should in (0, " + MAX_DECIMAL256_PRECISION + "], but real precision is " + precision); Preconditions.checkArgument(scale >= 0, "scale should not smaller than 0, but real scale is " + scale); Preconditions.checkArgument(precision >= scale, "precision should not smaller than scale," + " but precision is " + precision, ", scale is " + scale); - return new DecimalV3Type(precision, scale); + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); + if (precision > MAX_DECIMAL128_PRECISION && !enableDecimal256) { + throw new NotSupportedException("Datatype DecimalV3 with precision " + precision + ", which is greater than 38 is disabled by default. 
set enable_decimal256 = true to enable it."); + } else { + return new DecimalV3Type(precision, scale); + } } public static DecimalV3Type createDecimalV3Type(BigDecimal bigDecimal) { @@ -124,7 +137,12 @@ private static DataType widerDecimalV3Type( boolean overflowToDouble) { int scale = Math.max(leftScale, rightScale); int range = Math.max(leftPrecision - leftScale, rightPrecision - rightScale); - if (range + scale > MAX_DECIMAL128_PRECISION && overflowToDouble) { + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); + if (range + scale > MAX_DECIMAL128_PRECISION && overflowToDouble && !enableDecimal256) { + LOG.warn("sum test widerDecimalV3Type return double"); + for (StackTraceElement ste : Thread.currentThread().getStackTrace()) { + LOG.warn(ste.toString()); + } return DoubleType.INSTANCE; } return DecimalV3Type.createDecimalV3Type(range + scale, scale); @@ -193,8 +211,15 @@ public int width() { return 4; } else if (precision <= MAX_DECIMAL64_PRECISION) { return 8; - } else { + } else if (precision <= MAX_DECIMAL128_PRECISION) { return 16; + } else { + boolean enableDecimal256 = ConnectContext.get().getSessionVariable().enableDecimal256(); + if (enableDecimal256) { + return 32; + } else { + return 16; + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java index 296cac55bc7e0d7..26e8cb7f5e0b13c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java @@ -50,6 +50,7 @@ import org.apache.doris.mysql.MysqlProto; import org.apache.doris.mysql.MysqlSerializer; import org.apache.doris.mysql.MysqlServerStatusFlag; +import org.apache.doris.nereids.exceptions.NotSupportedException; import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.minidump.MinidumpUtils; import 
org.apache.doris.nereids.parser.NereidsParser; @@ -306,9 +307,13 @@ private void handleQuery(MysqlCommand mysqlCommand) { if (mysqlCommand == MysqlCommand.COM_QUERY && ctx.getSessionVariable().isEnableNereidsPlanner()) { try { stmts = new NereidsParser().parseSQL(originStmt); + } catch (NotSupportedException e) { + // Parse sql failed, audit it and return + handleQueryException(e, originStmt, null, null); + return; } catch (Exception e) { // TODO: We should catch all exception here until we support all query syntax. - LOG.debug("Nereids parse sql failed. Reason: {}. Statement: \"{}\".", + LOG.warn("Nereids parse sql failed. Reason: {}. Statement: \"{}\".", e.getMessage(), originStmt); } } @@ -390,6 +395,11 @@ private void handleQueryException(Throwable throwable, String origStmt, ctx.getState().setError(((UserException) throwable).getMysqlErrorCode(), throwable.getMessage()); // set it as ANALYSIS_ERR so that it won't be treated as a query failure. ctx.getState().setErrType(QueryState.ErrType.ANALYSIS_ERR); + } else if (throwable instanceof NotSupportedException) { + LOG.warn("Process one query failed because.", throwable); + ctx.getState().setError(ErrorCode.ERR_NOT_SUPPORTED_YET, throwable.getMessage()); + // set it as ANALYSIS_ERR so that it won't be treated as a query failure. + ctx.getState().setErrType(QueryState.ErrType.ANALYSIS_ERR); } else { // Catch all throwable. // If reach here, maybe palo bug. 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 4f0390799726f23..3036e7daab1d852 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -423,6 +423,8 @@ public class SessionVariable implements Serializable, Writable { public static final String FASTER_FLOAT_CONVERT = "faster_float_convert"; + public static final String ENABLE_DECIMAL256 = "enable_decimal256"; + public static final List DEBUG_VARIABLES = ImmutableList.of( SKIP_DELETE_PREDICATE, SKIP_DELETE_BITMAP, @@ -1254,6 +1256,10 @@ public void setIgnoreShapePlanNodes(String ignoreShapePlanNodes) { "the plan node type which is ignored in 'explain shape plan' command"}) public String ignoreShapePlanNodes = ""; + @VariableMgr.VarAttr(name = ENABLE_DECIMAL256, + description = {"控制是否在计算过程中使用Decimal256类型", "Set to true to enable Decimal256 type"}) + public boolean enableDecimal256 = false; + // If this fe is in fuzzy mode, then will use initFuzzyModeVariables to generate some variables, // not the default value set in the code. 
public void initFuzzyModeVariables() { @@ -2388,6 +2394,8 @@ public TQueryOptions toThrift() { tResult.setFasterFloatConvert(fasterFloatConvert); + tResult.setEnableDecimal256(enableNereidsPlanner && enableDecimal256); + return tResult; } @@ -2717,6 +2725,10 @@ public static boolean enableAggState() { return connectContext.getSessionVariable().enableAggState; } + public boolean enableDecimal256() { + return enableDecimal256; + } + public void checkAnalyzeTimeFormat(String time) { try { DateTimeFormatter timeFormatter = DateTimeFormatter.ofPattern("HH:mm:ss"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java b/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java index 0c36cb69e768c9e..0500814ed526c9e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/cache/PartitionRange.java @@ -189,6 +189,7 @@ public boolean init(Type type, LiteralExpr expr) { case DECIMAL32: case DECIMAL64: case DECIMAL128: + case DECIMAL256: case CHAR: case VARCHAR: case STRING: diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java index 3bd22cdbf24e9f6..509e78ffb8b6411 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java @@ -404,7 +404,8 @@ private Map> calcConstExpr(Mapselect k3, CAST(k3 AS DECIMALV3(18, 10)) from test_arithmetic_expressions_64; ++---------------------+-------------------------------+ +| k3 | cast(k3 as DECIMALV3(18, 10)) | ++---------------------+-------------------------------+ +| 333333333333.333333 | -552734400.8095512496 | +| 499999999999.999999 | 93235602.4711502064 | +| 999999999999.999999 | 186471204.9423014128 | +| 4.000000 | 4.0000000000 | ++---------------------+-------------------------------+ +4 rows in set (0.39 
sec) +*/ + + // decimal128 + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_128_1`"; + sql """ + CREATE TABLE IF NOT EXISTS `test_arithmetic_expressions_128_1` ( + `k1` decimalv3(38, 6) NULL COMMENT "", + `k2` decimalv3(38, 6) NULL COMMENT "", + `k3` decimalv3(38, 6) NULL COMMENT "" + ) ENGINE=OLAP + COMMENT "OLAP" + DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """insert into test_arithmetic_expressions_128_1 values(1, 99999999999999999999999999999999.999999, 99999999999999999999999999999999.999999), + (2, 49999999999999999999999999999999.999999, 49999999999999999999999999999999.999999), + (3, 33333333333333333333333333333333.333333, 33333333333333333333333333333333.333333), + (4.444444, 2.222222, 3.333333);""" + qt_decimal128_select_all "select * from test_arithmetic_expressions_128_1 order by k1, k2;" + // fix cast + // qt_decimal128_cast "select k3, CAST(k3 AS DECIMALV3(38, 10)) from test_arithmetic_expressions_128_1 order by 1, 2;" + /* + qt_decimal128_multiply_0 "select k1 * k2 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_arith_union "select * from (select k1 * k2 from test_arithmetic_expressions_128_1 union all select k3 from test_arithmetic_expressions_128_1) a order by 1" + qt_decimal128_multiply_1 "select k1 * k2 * k3 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_multiply_2 "select k1 * k2 * k3 * k1 * k2 * k3 from test_arithmetic_expressions_128_1 order by k1" + qt_decimal128_multiply_div "select k1 * k2 / k3 * k1 * k2 * k3 from test_arithmetic_expressions_128_1 order by k1" + */ + + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_128_2`"; + sql """ + CREATE TABLE IF NOT EXISTS test_arithmetic_expressions_128_2 ( + `a` DECIMALV3(38, 3) NOT NULL, + `b` DECIMALV3(38, 3) NOT NULL, + `c` DECIMALV3(38, 3) NOT NULL, + `d` DECIMALV3(38, 3) NOT NULL, + `e` DECIMALV3(38, 3) NOT NULL, + `f` DECIMALV3(38, 3) NOT 
NULL, + `g` DECIMALV3(38, 3) NOT NULL, + `h` DECIMALV3(38, 3) NOT NULL, + `i` DECIMALV3(38, 3) NOT NULL, + `j` DECIMALV3(38, 3) NOT NULL, + `k` DECIMALV3(38, 3) NOT NULL + ) DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1"); + """ + + sql """ + insert into test_arithmetic_expressions_128_2 values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999); + """ + qt_decimal128_select_all_2 "select * from test_arithmetic_expressions_128_2 order by a" + /* + qt_decimal128_mixed_calc_0 "select a + b + c from test_arithmetic_expressions_128_2;" + qt_decimal128_mixed_calc_1 "select (a + b + c) * d from test_arithmetic_expressions_128_2;" + qt_decimal128_mixed_calc_2 "select (a + b + c) / d from test_arithmetic_expressions_128_2;" + qt_decimal128_mixed_calc_3 "select a + b + c + d + e + f + g + h + i + j + k from test_arithmetic_expressions_128_2;" + */ + + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + qt_decimal128_cast256_cast "select k3, CAST(k3 AS DECIMALV3(76, 10)) from test_arithmetic_expressions_128_1 order by 1, 2;" + qt_decimal128_cast256_calc_0 "select cast(k1 as decimalv3(76, 6)) + k2 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_cast256_calc_1 "select cast(k2 as decimalv3(76, 6)) - k1 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_cast256_calc_2 "select cast(k1 as decimalv3(76, 6)) * k2 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_cast256_calc_4 "select k2, k1, cast(k2 as decimalv3(76, 6)) / k1 a from test_arithmetic_expressions_128_1 order by 1, 2;" + qt_decimal128_cast256_calc_5 "select k2, k1, cast(k2 as decimalv3(76, 6)) % k1 a from test_arithmetic_expressions_128_1 order by 1, 2;" + + qt_decimal128_cast256_calc_6 "select * from (select cast(k1 as decimalv3(76, 6)) * k2 from test_arithmetic_expressions_128_1 union all select k3 from test_arithmetic_expressions_128_1) a order by 
1" + // overflow + qt_decimal128_cast256_calc_7 "select cast(k1 as decimalv3(76, 6)) * k2 * k3 a from test_arithmetic_expressions_128_1 order by 1;" + qt_decimal128_cast256_calc_8 "select cast(k1 as decimalv3(76, 6)) * k2 * k3 * k1 * k2 * k3 from test_arithmetic_expressions_128_1 order by 1" + // qt_decimal128_cast256_calc_9 "select cast(k1 as decimalv3(76, 6)) * k2 / k3 * k1 * k2 * k3 from test_arithmetic_expressions_128_1 order by 1" + + qt_decimal128_cast256_mixed_calc_0 "select cast(a as decimalv3(39, 4)) + b + c from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_1 "select cast((a + b + c) as decimalv3(39, 4)) * d from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_2 "select cast((a + b + c) as decimalv3(39, 4)) / d from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_3 "select cast(a as decimalv3(39, 4)) + b + c + d + e + f + g + h + i + j + k from test_arithmetic_expressions_128_2 order by 1;" + + qt_decimal128_cast256_mixed_calc_4 "select cast(a as decimalv3(76, 6)) + b + c from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_5 "select cast((a + b + c) as decimalv3(76, 6)) * d from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_6 "select cast((a + b + c) as decimalv3(76, 6)) / d from test_arithmetic_expressions_128_2 order by 1;" + qt_decimal128_cast256_mixed_calc_7 "select cast(a as decimalv3(76, 6)) + b + c + d + e + f + g + h + i + j + k from test_arithmetic_expressions_128_2 order by 1;" + +/* +mysql [test]>select k3, CAST(k3 AS DECIMALV3(38, 10)) from test_arithmetic_expressions_128_1; ++-----------------------------------------+------------------------------------------+ +| k3 | cast(k3 as DECIMALV3(38, 10)) | ++-----------------------------------------+------------------------------------------+ +| 33333333333333333333333333333333.333333 | -9999999999999999999999999999.9999999999 
| +| 99999999999999999999999999999999.999999 | -9999999999999999999999999999.9999999999 | +| 49999999999999999999999999999999.999999 | 9999999999999999999999999999.9999999999 | +| 4.000000 | 4.0000000000 | ++-----------------------------------------+------------------------------------------+ +4 rows in set (0.07 sec) +*/ + + // decimal256 + /* + mysql [regression_test_datatype_p0_decimalv3]>select CAST(k3 AS DECIMALV3(76, 19)) from test_arithmetic_expressions_256_0; ++---------------------------------------------------------------------------------+ +| cast(k3 as DECIMALV3(76, 19)) | ++---------------------------------------------------------------------------------+ +| 3213777273360060490676974488410532053153505213067024283781.6774441511766597376 | +| -1717218125670499520334383174682575559673329346810002612127.4678374963165481728 | +| -5151654377011498561003149524047726679019988040430007836382.4035124889496445184 | ++---------------------------------------------------------------------------------+ + */ + /* + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_256_1`" + sql """ + CREATE TABLE IF NOT EXISTS `test_arithmetic_expressions_256_1` ( + `k1` decimalv3(76, 9) NULL COMMENT "", + `k2` decimalv3(76, 10) NULL COMMENT "", + `k3` decimalv3(76, 11) NULL COMMENT "" + ) ENGINE=OLAP + DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """insert into test_arithmetic_expressions_256_1 values(1, 999999999999999999999999999999999999999999999999999999999999999999.9999999999, 99999999999999999999999999999999999999999999999999999999999999999.99999999999), + (2, 499999999999999999999999999999999999999999999999999999999999999999.9999999999, 49999999999999999999999999999999999999999999999999999999999999999.99999999999), + (3, 333333333333333333333333333333333333333333333333333333333333333333.3333333333, 33333333333333333333333333333333333333333333333333333333333333333.33333333333);""" + 
qt_decimal256_arith_select_all "select * from test_arithmetic_expressions_256_1 order by k1, k2, k3;" + qt_decimal256_arith_plus "select k1 + k2 from test_arithmetic_expressions_256_1 order by 1;" + qt_decimal256_arith_minus "select k2 - k1 from test_arithmetic_expressions_256_1 order by 1;" + qt_decimal256_arith_multiply "select k1 * k2 from test_arithmetic_expressions_256_1 order by 1;" + qt_decimal256_arith_div "select k2 / k1 from test_arithmetic_expressions_256_1 order by 1;" + qt_decimal256_arith_union "select * from (select k1 * k2 from test_arithmetic_expressions_256_1 union all select k3 from test_arithmetic_expressions_256_1) a order by 1" + + qt_decimal256_multiply_1 "select k1 * k2 * k3 a from test_arithmetic_expressions_256_1 order by 1;" + qt_decimal256_multiply_2 "select k1 * k2 * k3 * k1 * k2 * k3 from test_arithmetic_expressions_256_1 order by k1" + qt_decimal256_multiply_div "select k1 * k2 / k3 * k1 * k2 * k3 from test_arithmetic_expressions_256_1 order by k1" + + qt_decimal256_arith_multiply_const "select k1 * 2.0 from test_arithmetic_expressions_256_1 order by 1;" + + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_256_2`"; + sql """ + CREATE TABLE IF NOT EXISTS test_arithmetic_expressions_256_2 ( + `a` DECIMALV3(76, 3) NOT NULL, + `b` DECIMALV3(76, 3) NOT NULL, + `c` DECIMALV3(76, 3) NOT NULL, + `d` DECIMALV3(76, 3) NOT NULL, + `e` DECIMALV3(76, 3) NOT NULL, + `f` DECIMALV3(76, 3) NOT NULL, + `g` DECIMALV3(76, 3) NOT NULL, + `h` DECIMALV3(76, 3) NOT NULL, + `i` DECIMALV3(76, 3) NOT NULL, + `j` DECIMALV3(76, 3) NOT NULL, + `k` DECIMALV3(76, 3) NOT NULL + ) DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1"); + """ + + sql """ + insert into test_arithmetic_expressions_256_2 values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999); + """ + qt_decimal256_select_all_2 "select * from test_arithmetic_expressions_256_2 order by a" + + qt_decimal256_mixed_calc_0 
"select a + b + c from test_arithmetic_expressions_256_2;" + qt_decimal256_mixed_calc_1 "select (a + b + c) * d from test_arithmetic_expressions_256_2;" + qt_decimal256_mixed_calc_2 "select (a + b + c) / d from test_arithmetic_expressions_256_2;" + qt_decimal256_mixed_calc_3 "select a + b + c + d + e + f + g + h + i + j + k from test_arithmetic_expressions_256_2;" + + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_256_3`" + sql """ + CREATE TABLE IF NOT EXISTS `test_arithmetic_expressions_256_3` ( + `k1` decimalv3(76, 0) NULL COMMENT "", + `k2` decimalv3(76, 1) NULL COMMENT "", + `k3` decimalv3(76, 2) NULL COMMENT "" + ) ENGINE=OLAP + DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """insert into test_arithmetic_expressions_256_3 values(1, 999999999999999999999999999999999999999999999999999999999999999999999999999.9, 99999999999999999999999999999999999999999999999999999999999999999999999999.99), + (2, 499999999999999999999999999999999999999999999999999999999999999999999999999.9, 49999999999999999999999999999999999999999999999999999999999999999999999999.99), + (3, 333333333333333333333333333333333333333333333333333333333333333333333333333.3, 33333333333333333333333333333333333333333333333333333333333333333333333333.33);""" + qt_decimal256_arith_3 "select k1, k2, k1 * k2 a from test_arithmetic_expressions_256_3 order by k1, k2;" + + sql "DROP TABLE IF EXISTS `test_arithmetic_expressions_256_4`" + sql """ create table test_arithmetic_expressions_256_4 ( + id smallint, + fz decimal(27,9), + fzv3 decimalv3(76,9), + fm decimalv3(76,10)) + DISTRIBUTED BY HASH(`id`) BUCKETS auto + PROPERTIES + ( + "replication_num" = "1" + ); """ + + sql """ insert into test_arithmetic_expressions_256_4 values (1,92594283.129196000,92594283.129196000,147202.0000000000); """ + sql """ insert into test_arithmetic_expressions_256_4 values 
(2,107684988.257976000,107684988.257976000,148981.0000000000); """ + sql """ insert into test_arithmetic_expressions_256_4 values (3,76891560.464178000,76891560.464178000,106161.0000000000); """ + sql """ insert into test_arithmetic_expressions_256_4 values (4,277170831.851350000,277170831.851350000,402344.0000000000); """ + + qt_decimal256_div_v2_v3 """ select id, fz/fm as dec,fzv3/fm as decv3 from test_arithmetic_expressions_256_4 ORDER BY id; """ + + sql "drop table if exists test_arithmetic_expressions_256_5" + sql """ create table test_arithmetic_expressions_256_5 ( + id smallint, + v1 decimalv3(27,9), + v2 decimalv3(9,0), + v3 int ) + DISTRIBUTED BY HASH(`id`) BUCKETS auto + PROPERTIES + ( + "replication_num" = "1" + ); """ + + sql """ insert into test_arithmetic_expressions_256_5 values (1,92594283.129196000,1,1); """ + sql """ insert into test_arithmetic_expressions_256_5 values (2,107684988.257976000,3,3); """ + sql """ insert into test_arithmetic_expressions_256_5 values (3,76891560.464178000,5,5); """ + sql """ insert into test_arithmetic_expressions_256_5 values (4,277170831.851350000,7,7); """ + + qt_decimal256_mod """ select v1, v2, v1 % v2, v1 % v3 from test_arithmetic_expressions_256_5 ORDER BY id; """ + */ + } diff --git a/regression-test/suites/datatype_p0/decimalv3/test_decimalv3.groovy b/regression-test/suites/datatype_p0/decimalv3/test_decimalv3.groovy index 2b72c36867bba1d..d67f927f4b5e745 100644 --- a/regression-test/suites/datatype_p0/decimalv3/test_decimalv3.groovy +++ b/regression-test/suites/datatype_p0/decimalv3/test_decimalv3.groovy @@ -20,19 +20,42 @@ suite("test_decimalv3") { sql "CREATE DATABASE IF NOT EXISTS ${db}" sql "use ${db}" sql "drop table if exists test5" - sql '''CREATE TABLE test5 ( `a` decimalv3(38,18), `b` decimalv3(38,18) ) ENGINE=OLAP DUPLICATE KEY(`a`) COMMENT 'OLAP' DISTRIBUTED BY HASH(`a`) BUCKETS 1 PROPERTIES ( "replication_allocation" = "tag.location.default: 1" ) ''' - sql "insert into test5 values(50,2)" - sql 
"drop view if exists test5_v" - sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5" + sql '''CREATE TABLE test5 ( `a` decimalv3(38,18), `b` decimalv3(38,18) ) ENGINE=OLAP DUPLICATE KEY(`a`) COMMENT 'OLAP' DISTRIBUTED BY HASH(`a`) BUCKETS 1 PROPERTIES ( "replication_allocation" = "tag.location.default: 1" ) ''' + sql "insert into test5 values(50,2)" + sql "drop view if exists test5_v" + sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5" - qt_decimalv3 "select * from test5_v" - qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5" + qt_decimalv3 "select * from test5_v" + qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5" + + /* + sql "drop table if exists test_decimal256;" + sql """ create table test_decimal256(k1 decimal(76, 6), v1 decimal(76, 6)) + DUPLICATE KEY(`k1`, `v1`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 10 + properties("replication_num" = "1"); """ + sql """insert into test_decimal256 values(1, 9999999999999999999999999999999999999999999999999999999999999999999999.999999), + (2, 4999999999999999999999999999999999999999999999999999999999999999999999.999999);""" + qt_decimalv3_0 "select * from test_decimal256 order by k1, v1; " + qt_decimalv3_1 "select * from test_decimal256 where v1 = 9999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + qt_decimalv3_2 "select * from test_decimal256 where v1 != 9999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + qt_decimalv3_3 "select * from test_decimal256 where v1 > 4999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + qt_decimalv3_4 "select * from test_decimal256 where v1 >= 4999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + qt_decimalv3_5 "select * from test_decimal256 where v1 < 
9999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + qt_decimalv3_6 "select * from test_decimal256 where v1 <= 9999999999999999999999999999999999999999999999999999999999999999999999.999999 order by k1, v1; " + */ + + sql "set experimental_enable_nereids_planner =false;" + qt_aEb_test1 "select 0e0;" + qt_aEb_test2 "select 1e-1" + qt_aEb_test3 "select -1e-2" + qt_aEb_test4 "select 10.123456e10;" + qt_aEb_test5 "select 123456789e-10" + qt_aEb_test6 "select 0.123445e10;" + + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + qt_decimal256_cast_0 """ select cast("999999.999999" as decimal(76,6));""" + qt_decimal256_cast_1 """select cast("9999999999999999999999999999999999999999999999999999999999999999999999.999999" as decimal(76,6));""" - sql "set experimental_enable_nereids_planner =false;" - qt_aEb_test1 "select 0e0;" - qt_aEb_test2 "select 1e-1" - qt_aEb_test3 "select -1e-2" - qt_aEb_test4 "select 10.123456e10;" - qt_aEb_test5 "select 123456789e-10" - qt_aEb_test6 "select 0.123445e10;" } diff --git a/regression-test/suites/datatype_p0/decimalv3/test_decimalv3_overflow.groovy b/regression-test/suites/datatype_p0/decimalv3/test_decimalv3_overflow.groovy index bb4b4ba42d736ed..bf4b1ef220b0fa3 100644 --- a/regression-test/suites/datatype_p0/decimalv3/test_decimalv3_overflow.groovy +++ b/regression-test/suites/datatype_p0/decimalv3/test_decimalv3_overflow.groovy @@ -42,6 +42,7 @@ suite("test_decimalv3_overflow") { ); """ sql "insert into ${tblName2} values('2022-08-01', 705091149953414452.46)" + // modify case qt_sql """ select c2 / 10000 * c1 from ${tblName1}, ${tblName2}; """ sql """ set check_overflow_for_decimal=true; """ diff --git a/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy b/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy index 429f98b94a51f5f..d347c83ebe652a1 100644 --- a/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy +++ 
b/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy @@ -43,5 +43,57 @@ suite("test_predicate") { qt_select2 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ 1 FROM ${table1} WHERE CAST((CASE WHEN (TRUE IS NOT NULL) THEN '1.2' ELSE '1.2' END) AS FLOAT) = CAST(1.2 AS decimal(2,1));" qt_select3 "SELECT * FROM ${table1} WHERE k1 != 1.1 ORDER BY k1" + + // decimal256 + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + qt_select4 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ CAST((CASE WHEN (TRUE IS NOT NULL) THEN '1.2' ELSE '1.2' END) AS FLOAT) = CAST(1.2 AS decimal(76,1))" + qt_select5 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ 1 FROM ${table1} WHERE CAST((CASE WHEN (TRUE IS NOT NULL) THEN '1.2' ELSE '1.2' END) AS FLOAT) = CAST(1.2 AS decimal(76,1));" + qt_select6 "SELECT * FROM ${table1} WHERE k1 != cast(1.1 as decimalv3(76, 1)) ORDER BY k1" sql "drop table if exists ${table1}" + + qt_select256_1 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) > cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_2 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10)) > cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_3 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) >= cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_4 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999997 
as decimalv3(76,10)) >= cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + + qt_select256_5 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10)) < cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10))" + qt_select256_6 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) < cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_7 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10)) <= cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_8 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) <= cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + + qt_select256_9 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) = cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10))" + qt_select256_10 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) = cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + + qt_select256_11 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ 
cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) != cast(999999999999999999999999999999999999999999999999999999999999999999.9999999998 as decimalv3(76,10))" + qt_select256_12 "SELECT /*+ SET_VAR(enable_fold_constant_by_be = false) */ cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10)) != cast(999999999999999999999999999999999999999999999999999999999999999999.9999999999 as decimalv3(76,10))" + + + sql "DROP TABLE IF EXISTS `test_predicate_128_1`"; + sql """ + CREATE TABLE IF NOT EXISTS `test_predicate_128_1` ( + `k1` decimalv3(38, 6) NULL COMMENT "", + `k2` decimalv3(38, 6) NULL COMMENT "", + `k3` decimalv3(38, 6) NULL COMMENT "" + ) ENGINE=OLAP + COMMENT "OLAP" + DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """insert into test_predicate_128_1 values(1, 99999999999999999999999999999999.999999, 99999999999999999999999999999999.999999), + (2, 49999999999999999999999999999999.999999, 49999999999999999999999999999999.999999), + (3, 33333333333333333333333333333333.333333, 33333333333333333333333333333333.333333), + (4.444444, 2.222222, 3.333333);""" + qt_decimal256_select_all "select * from test_predicate_128_1 order by k1, k2;" + qt_decimal256_predicate_0 "select * from test_predicate_128_1 where k2 > (cast(33333333333333333333333333333333.333333 as decimalv3(76,7))) order by k1, k2;" + qt_decimal256_predicate_1 "select * from test_predicate_128_1 where k2 >= (cast(999999999999999999999999999999990.999999 as decimalv3(76,6)) / 10)order by k1, k2;" + + qt_decimal256_predicate_2 "select * from test_predicate_128_1 where k2 < (cast(49999999999999999999999999999999.999999 as decimalv3(76,7))) order by k1, k2;" + qt_decimal256_predicate_3 "select * from test_predicate_128_1 where k2 <= (cast(33333333333333333333333333333333.333333 as decimalv3(76,7))) order by k1, k2;" + + 
qt_decimal256_predicate_4 "select * from test_predicate_128_1 where k2 = (cast(99999999999999999999999999999999.999999 as decimalv3(76,7))) order by k1, k2;" + qt_decimal256_predicate_5 "select * from test_predicate_128_1 where k2 != (cast(99999999999999999999999999999999.999999 as decimalv3(76,7))) order by k1, k2;" + } diff --git a/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy b/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy new file mode 100644 index 000000000000000..34daa1ecabd663c --- /dev/null +++ b/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("aggregate_decimal256") { + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + sql "drop table if exists test_aggregate_decimal256_sum;" + sql """ create table test_aggregate_decimal256_sum(k1 int, v1 decimal(38, 6), v2 decimal(38, 6)) + DUPLICATE KEY(`k1`, `v1`, `v2`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 10 + properties("replication_num" = "1"); """ + + sql """insert into test_aggregate_decimal256_sum values + (1, 1.000000, 99999999999999999999999999999999.999999), + (1, 1.000000, 99999999999999999999999999999999.999999), + (1, 1.000000, -999999.200002), + (1, 1.000000, 999999.200002), + (2, 11.000000, 99999999999999999999999999999999.999999), + (2, 11.000000, 99999999999999999999999999999999.999999), + (2, 11.000000, -999999.200002), + (2, 11.000000, 999999.200002);""" + sql "sync" + + qt_sql_sum_1 """ select k1, sum(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_sum where v1 = 1 group by k1 order by 1, 2; """ + qt_sql_sum_2 """ select k1, sum(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_sum where v1 = 11 group by k1 order by 1, 2; """ + qt_sql_sum_3 """ select k1, sum(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_sum group by k1 order by 1, 2; """ + qt_sql_sum_4 """ + select + k1, + sum(sum_val) + from + ( + ( + select + k1, + sum(cast(v2 as decimalv3(39, 6))) as sum_val + from + test_aggregate_decimal256_sum + where + v1 = 1 + group by k1 + ) + union + all ( + select + k1, + sum(cast(v2 as decimalv3(39, 6))) as sum_val + from + test_aggregate_decimal256_sum + where + v1 = 11 + group by k1 + ) + ) union1 group by k1 + order by 1, 2; + """ + + qt_sql_sum_5 """ select cast(v2 as decimalv3(39, 6)) v2_cast, sum(k1) from test_aggregate_decimal256_sum group by v2_cast order by 1, 2; """ + + sql """insert into test_aggregate_decimal256_sum values + (1, 1.000000, -999999.200002), + (1, 1.000000, 999999.200002), + (2, 11.000000, -999999.200002), + (2, 11.000000, 999999.200002);""" + sql 
"sync" + qt_sql_sum_6 """ select cast(v1 as decimalv3(39, 6)) v1_cast, cast(v2 as decimalv3(39, 6)) v2_cast, sum(k1) from test_aggregate_decimal256_sum group by v1_cast, v2_cast order by 1, 2, 3; """ + qt_sql_sum_7 """ select sum(cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_sum order by 1; """ + qt_sql_sum_8 """ select sum(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_sum order by 1; """ + qt_sql_sum_9 """ select sum(cast(v1 as decimalv3(39, 6))), sum(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_sum order by 1, 2; """ + + sql "drop table if exists test_aggregate_decimal256_avg;" + sql """ create table test_aggregate_decimal256_avg(k1 int, v1 decimal(38, 6), v2 decimal(38, 6)) + DUPLICATE KEY(`k1`, `v1`, `v2`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 10 + properties("replication_num" = "1"); """ + + sql """insert into test_aggregate_decimal256_avg values + (1, 1.000000, 99999999999999999999999999999999.999999), + (1, 1.000000, 99999999999999999999999999999999.999999), + (1, 1.000000, -999999.200002), + (1, 1.000000, 999999.200002), + (2, 11.000000, 99999999999999999999999999999999.999999), + (2, 11.000000, 99999999999999999999999999999999.999999), + (2, 11.000000, -999999.200002), + (2, 11.000000, 999999.200002);""" + sql "sync" + qt_sql_avg_1 """ select k1, avg(cast(v2 as decimalv3(76, 6))) from test_aggregate_decimal256_avg where v1 = 1 group by k1 order by 1, 2; """ + qt_sql_avg_2 """ select k1, avg(cast(v2 as decimalv3(76, 6))) from test_aggregate_decimal256_avg where v1 = 11 group by k1 order by 1, 2; """ + qt_sql_avg_3 """ select k1, avg(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2; """ + qt_sql_avg_4 """ + select + k1, + avg(avg_val) + from + ( + ( + select + k1, + avg(cast(v2 as decimalv3(39, 6))) as avg_val + from + test_aggregate_decimal256_avg + where + v1 = 1 + group by k1 + ) + union + all ( + select + k1, + avg(cast(v2 as decimalv3(39, 6))) as avg_val + from + 
test_aggregate_decimal256_avg + where + v1 = 11 + group by k1 + ) + ) union1 group by k1 + order by 1, 2; + """ + qt_sql_avg_5 """ select avg(cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_avg_6 """ select avg(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_avg_7 """ select avg(cast(v1 as decimalv3(39, 6))), avg(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1, 2; """ + + qt_sql_max_1 """ select k1, max(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2; """ + qt_sql_max_2 """ select max(cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_max_3 """ select max(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_max_4 """ select max(cast(v1 as decimalv3(39, 6))), max(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + + qt_sql_min_1 """ select k1, min(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2; """ + qt_sql_min_2 """ select min(cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_min_3 """ select min(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + qt_sql_min_4 """ select min(cast(v1 as decimalv3(39, 6))), min(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1; """ + + qt_sql_count_1 """ select k1, count(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2; """ + qt_sql_count_2 """ select k1, count(cast(v1 as decimalv3(39, 6))), count(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2, 3; """ + qt_sql_count_3 """ select count(cast(v1 as decimalv3(39, 6))), count(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1, 2; """ +} diff --git 
a/regression-test/suites/query_p0/join/test_join_decimal256.groovy b/regression-test/suites/query_p0/join/test_join_decimal256.groovy new file mode 100644 index 000000000000000..7f6768ee3bf0703 --- /dev/null +++ b/regression-test/suites/query_p0/join/test_join_decimal256.groovy @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// The cases are copied from https://github.com/trinodb/trino/tree/master +// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/aggregate +// and modified by Doris.
+ +suite("join_decimal256") { + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + sql "drop table if exists test_join_decimal256_0;" + sql """ create table test_join_decimal256_0(k1 int, v1 decimal(38, 6), v2 decimal(38, 6)) + DUPLICATE KEY(`k1`, `v1`, `v2`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 10 + properties("replication_num" = "1"); """ + + sql """insert into test_join_decimal256_0 values + (10, 10.000000, 99999999999999999999999999999999.999999), (10, 10.000000, 0.000001), (10, -10.000000, -0.000001), + (110, 110.000000, 99999999999999999999999999999999.999999), (110, 110.000000, 0.000001), (110, -110.000000, -0.000001);""" + + sql "drop table if exists test_join_decimal256_1;" + sql """ create table test_join_decimal256_1(k1 int, v1 decimal(38, 6), v2 decimal(38, 6)) + DUPLICATE KEY(`k1`, `v1`, `v2`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 10 + properties("replication_num" = "1"); """ + + sql """insert into test_join_decimal256_1 values + (11, 10.000000, 99999999999999999999999999999999.999999), (111, 10.000000, 99999999999999999999999999999999.999999), + (11, 10.000000, 0.000001), (111, 10.000000, 0.000001), + (11, -10.000000, -0.000001), (111, -10.000000, -0.000001), + (11, 110.000000, 99999999999999999999999999999999.999999),(111, 110.000000, 99999999999999999999999999999999.999999), + (11, 110.000000, 0.000001),(111, 110.000000, 0.000001), + (11, -110.000000, -0.000001), (111, -110.000000, -0.000001);""" + sql "sync" + + qt_join_1 """ + select + t0.v2_cast, t1.v2_cast, t0.k1, t0.v1, t1.k1, t1.v1 + from + ( + select + k1, + v1, + cast(v2 as decimal(76, 6)) v2_cast + from + test_join_decimal256_0 + ) t0 + inner join ( + select + k1, + v1, + cast(v2 as decimal(76, 6)) v2_cast + from + test_join_decimal256_1 + ) t1 on t0.v2_cast = t1.v2_cast + order by + 1,2,3,4,5,6; + """ + + qt_join_2 """ + select + t0.v1_cast, t0.v2_cast, t1.v1_cast, t1.v2_cast, t0.k1, t1.k1 + from + ( + select + k1, + cast(v1 as decimal(76, 6)) v1_cast, + cast(v2 
as decimal(76, 6)) v2_cast + from + test_join_decimal256_0 + ) t0 + inner join ( + select + k1, + cast(v1 as decimal(76, 6)) v1_cast, + cast(v2 as decimal(76, 6)) v2_cast + from + test_join_decimal256_1 + ) t1 on t0.v1_cast = t1.v1_cast and t0.v2_cast = t1.v2_cast + order by + 1,2,3,4,5,6; + """ +} \ No newline at end of file