diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala index f772c909c8f2..c0a32a034670 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala @@ -335,6 +335,24 @@ class GlutenClickHouseDecimalSuite spark.sql("drop table if exists decimals_test") } } + + test("test castornull") { + // prepare + val createSql = + "create table decimals_cast_test(a decimal(18,8)) using parquet" + val inserts = + "insert into decimals_cast_test values(123456789.12345678)" + spark.sql(createSql) + + try { + spark.sql(inserts) + val q1 = "select cast(a as decimal(9,2)) from decimals_cast_test" + compareResultsAgainstVanillaSpark(q1, compareResult = true, _ => {}) + } finally { + spark.sql("drop table if exists decimals_cast_test") + } + } + // FIXME: Support AVG for Decimal Type Seq("true", "false").foreach { allowPrecisionLoss => diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp index 26d8e0deb3a9..a2d5dec73945 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp @@ -121,23 +121,32 @@ struct SparkDecimalBinaryOperation if constexpr (Mode == OpMode::Effect) { - return executeDecimalImpl( + return executeDecimalImpl>( left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); } if (calculateWith256(*arguments[0].type.get(), *arguments[1].type.get())) { - return executeDecimalImpl( + return executeDecimalImpl( left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); } - return executeDecimalImpl( - left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), result.getScale()); + if (is_division && max_scale - left.getScale() + max_scale > DataTypeDecimal::maxPrecision()) + { + return executeDecimalImpl( + left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + } + else + { + return executeDecimalImpl>( + left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result); + } } private: // ResultDataType e.g. DataTypeDecimal - template + template static ColumnPtr executeDecimalImpl( const auto & left, const auto & right, @@ -152,34 +161,29 @@ struct SparkDecimalBinaryOperation using RightFieldType = typename RightDataType::FieldType; using ResultFieldType = typename ResultDataType::FieldType; - using NativeResultType = NativeType; using ColVecResult = ColumnVectorOrDecimal; - size_t max_scale; - if constexpr (is_multiply) - max_scale = left.getScale() + right.getScale(); - else - max_scale = std::max(resultDataType.getScale(), std::max(left.getScale(), right.getScale())); + size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), resultDataType.getScale()); - NativeResultType scale_left = [&] + ScaleDataType scale_left = [&] { if constexpr (is_multiply) - return NativeResultType{1}; + return ScaleDataType{1}; // cast scale same to left auto diff_scale = max_scale - left.getScale(); if constexpr (is_division) - return DecimalUtils::scaleMultiplier(diff_scale + max_scale); + return DecimalUtils::scaleMultiplier(diff_scale + max_scale); else - return DecimalUtils::scaleMultiplier(diff_scale); + return DecimalUtils::scaleMultiplier(diff_scale); }(); - const NativeResultType scale_right = [&] + const ScaleDataType scale_right = [&] { if constexpr (is_multiply) - return NativeResultType{1}; + return ScaleDataType{1}; else - return DecimalUtils::scaleMultiplier(max_scale - right.getScale()); + return DecimalUtils::scaleMultiplier(max_scale - right.getScale()); }(); @@ -266,17 +270,19 @@ struct SparkDecimalBinaryOperation return ColumnNullable::create(std::move(col_res), std::move(col_null_map_to)); } - template + template static static void NO_INLINE process( const auto & a, const auto & b, ResultContainerType & result_container, - const NativeResultType & scale_a, - const NativeResultType & scale_b, + const ScaleDataType & scale_a, + const ScaleDataType & scale_b, ColumnUInt8::Container & vec_null_map_to, const ResultDataType & resultDataType, const size_t & max_scale) { + using NativeResultType = NativeType; + size_t size; if constexpr (op_case == OpCase::LeftConstant) size = b.size(); @@ -303,14 +309,14 @@ struct SparkDecimalBinaryOperation } else if constexpr (op_case == OpCase::LeftConstant) { - auto scaled_a = applyScaled(unwrap(a, 0), scale_a); + ScaleDataType scaled_a = applyScaled(unwrap(a, 0), scale_a); for (size_t i = 0; i < size; ++i) { NativeResultType res; if (calculate( scaled_a, unwrap(b, i), - static_cast(0), + static_cast(0), scale_b, res, resultDataType, @@ -322,7 +328,7 @@ struct SparkDecimalBinaryOperation } else if constexpr (op_case == OpCase::RightConstant) { - auto scaled_b = applyScaled(unwrap(b, 0), scale_b); + ScaleDataType scaled_b = applyScaled(unwrap(b, 0), scale_b); for (size_t i = 0; i < size; ++i) { @@ -331,7 +337,7 @@ struct SparkDecimalBinaryOperation unwrap(a, i), scaled_b, scale_a, - static_cast(0), + static_cast(0), res, resultDataType, max_scale)) @@ -343,12 +349,12 @@ struct SparkDecimalBinaryOperation } // ResultNativeType = Int32/64/128/256 - template + template static NO_SANITIZE_UNDEFINED bool calculate( const LeftNativeType l, const RightNativeType r, - const NativeResultType & scale_left, - const NativeResultType & scale_right, + const ScaleDataType & scale_left, + const ScaleDataType & scale_right, NativeResultType & res, const ResultDataType & resultDataType, const size_t & max_scale) @@ -361,12 +367,12 @@ struct SparkDecimalBinaryOperation return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); } - template + template static NO_SANITIZE_UNDEFINED bool calculateImpl( const LeftNativeType & l, const RightNativeType & r, - const NativeResultType & scale_left, - const NativeResultType & scale_right, + const ScaleDataType & scale_left, + const ScaleDataType & scale_right, NativeResultType & res, const ResultDataType & resultDataType, const size_t & max_scale) @@ -410,13 +416,21 @@ struct SparkDecimalBinaryOperation return elem[i].value; } - template - static ResultNativeType applyScaled(const NativeType & l, const ResultNativeType & scale) + template + static ScaleType applyScaled(const NativeType & l, const ScaleType & scale) { if (scale > 1) return common::mulIgnoreOverflow(l, scale); - return static_cast(l); + return static_cast(l); + } + + static size_t getMaxScaled(const size_t left_scale, const size_t right_scale, const size_t result_scale) + { + if constexpr (is_multiply) + return left_scale + right_scale; + else + return std::max(result_scale, std::max(left_scale, right_scale)); } }; diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 75ba2a115234..9f90ae3642e2 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -1032,6 +1032,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & acti String function_name = "sparkCastFloatTo" + non_nullable_output_type->getName(); function_node = toFunctionNode(actions_dag, function_name, args); } + else if ((isDecimal(non_nullable_input_type) && substrait_type.has_decimal())) + { + args.emplace_back(addColumn(actions_dag, std::make_shared(), substrait_type.decimal().precision())); + args.emplace_back(addColumn(actions_dag, std::make_shared(), substrait_type.decimal().scale())); + + function_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args); + } else { if (isString(non_nullable_input_type) && isInt(non_nullable_output_type))