Skip to content

Commit

Permalink
[GLUTEN-6975][CH] Fix decimal cast overflow exception
Browse files Browse the repository at this point in the history
  • Loading branch information
loneylee authored Sep 25, 2024
1 parent cca659a commit 6b1d63c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,24 @@ class GlutenClickHouseDecimalSuite
spark.sql("drop table if exists decimals_test")
}
}

test("test castornull") {
  // Regression test for GLUTEN-6975: cast a decimal(18,8) value that does not
  // fit into decimal(9,2) and compare against vanilla Spark.
  // NOTE(review): presumably the overflowing cast yields null rather than
  // throwing (non-ANSI Spark semantics) — verify against the backend change.
  spark.sql("create table decimals_cast_test(a decimal(18,8)) using parquet")

  try {
    spark.sql("insert into decimals_cast_test values(123456789.12345678)")
    compareResultsAgainstVanillaSpark(
      "select cast(a as decimal(9,2)) from decimals_cast_test",
      compareResult = true,
      _ => {})
  } finally {
    // Always clean up the table, even if the comparison fails.
    spark.sql("drop table if exists decimals_cast_test")
  }
}

// FIXME: Support AVG for Decimal Type
Seq("true", "false").foreach {
allowPrecisionLoss =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,32 @@ struct SparkDecimalBinaryOperation

if constexpr (Mode == OpMode::Effect)
{
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, NativeType<typename ResultDataType::FieldType>>(
left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
}

if (calculateWith256<is_plus_minus, is_multiply, is_division, is_modulo>(*arguments[0].type.get(), *arguments[1].type.get()))
{
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, true>(
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int256, true>(
left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
}

return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), result.getScale());
if (is_division && max_scale - left.getScale() + max_scale > DataTypeDecimal<typename ResultDataType::FieldType>::maxPrecision())
{
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, Int256, true>(
left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
}
else
{
return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, NativeType<typename ResultDataType::FieldType>>(
left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
}
}

private:
// ResultDataType e.g. DataTypeDecimal<Decimal32>
template <class LeftDataType, class RightDataType, class ResultDataType, bool CalculateWith256 = false>
template <class LeftDataType, class RightDataType, class ResultDataType, class ScaleDataType, bool CalculateWith256 = false>
static ColumnPtr executeDecimalImpl(
const auto & left,
const auto & right,
Expand All @@ -152,34 +161,29 @@ struct SparkDecimalBinaryOperation
using RightFieldType = typename RightDataType::FieldType;
using ResultFieldType = typename ResultDataType::FieldType;

using NativeResultType = NativeType<ResultFieldType>;
using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;

size_t max_scale;
if constexpr (is_multiply)
max_scale = left.getScale() + right.getScale();
else
max_scale = std::max(resultDataType.getScale(), std::max(left.getScale(), right.getScale()));
size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), resultDataType.getScale());

NativeResultType scale_left = [&]
ScaleDataType scale_left = [&]
{
if constexpr (is_multiply)
return NativeResultType{1};
return ScaleDataType{1};

// cast scale same to left
auto diff_scale = max_scale - left.getScale();
if constexpr (is_division)
return DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale + max_scale);
return DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale + max_scale);
else
return DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale);
return DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale);
}();

const NativeResultType scale_right = [&]
const ScaleDataType scale_right = [&]
{
if constexpr (is_multiply)
return NativeResultType{1};
return ScaleDataType{1};
else
return DecimalUtils::scaleMultiplier<NativeResultType>(max_scale - right.getScale());
return DecimalUtils::scaleMultiplier<ScaleDataType>(max_scale - right.getScale());
}();


Expand Down Expand Up @@ -266,17 +270,19 @@ struct SparkDecimalBinaryOperation
return ColumnNullable::create(std::move(col_res), std::move(col_null_map_to));
}

template <OpCase op_case, bool CalculateWith256, typename ResultContainerType, typename NativeResultType, typename ResultDataType>
template <OpCase op_case, bool CalculateWith256, typename ResultContainerType, typename ResultDataType, typename ScaleDataType>
static static void NO_INLINE process(
const auto & a,
const auto & b,
ResultContainerType & result_container,
const NativeResultType & scale_a,
const NativeResultType & scale_b,
const ScaleDataType & scale_a,
const ScaleDataType & scale_b,
ColumnUInt8::Container & vec_null_map_to,
const ResultDataType & resultDataType,
const size_t & max_scale)
{
using NativeResultType = NativeType<typename ResultDataType::FieldType>;

size_t size;
if constexpr (op_case == OpCase::LeftConstant)
size = b.size();
Expand All @@ -303,14 +309,14 @@ struct SparkDecimalBinaryOperation
}
else if constexpr (op_case == OpCase::LeftConstant)
{
auto scaled_a = applyScaled(unwrap<op_case, OpCase::LeftConstant>(a, 0), scale_a);
ScaleDataType scaled_a = applyScaled(unwrap<op_case, OpCase::LeftConstant>(a, 0), scale_a);
for (size_t i = 0; i < size; ++i)
{
NativeResultType res;
if (calculate<CalculateWith256>(
scaled_a,
unwrap<op_case, OpCase::RightConstant>(b, i),
static_cast<NativeResultType>(0),
static_cast<ScaleDataType>(0),
scale_b,
res,
resultDataType,
Expand All @@ -322,7 +328,7 @@ struct SparkDecimalBinaryOperation
}
else if constexpr (op_case == OpCase::RightConstant)
{
auto scaled_b = applyScaled(unwrap<op_case, OpCase::RightConstant>(b, 0), scale_b);
ScaleDataType scaled_b = applyScaled(unwrap<op_case, OpCase::RightConstant>(b, 0), scale_b);

for (size_t i = 0; i < size; ++i)
{
Expand All @@ -331,7 +337,7 @@ struct SparkDecimalBinaryOperation
unwrap<op_case, OpCase::LeftConstant>(a, i),
scaled_b,
scale_a,
static_cast<NativeResultType>(0),
static_cast<ScaleDataType>(0),
res,
resultDataType,
max_scale))
Expand All @@ -343,12 +349,12 @@ struct SparkDecimalBinaryOperation
}

// ResultNativeType = Int32/64/128/256
template <bool CalculateWith256, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType>
template <bool CalculateWith256, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType, typename ScaleDataType>
static NO_SANITIZE_UNDEFINED bool calculate(
const LeftNativeType l,
const RightNativeType r,
const NativeResultType & scale_left,
const NativeResultType & scale_right,
const ScaleDataType & scale_left,
const ScaleDataType & scale_right,
NativeResultType & res,
const ResultDataType & resultDataType,
const size_t & max_scale)
Expand All @@ -361,12 +367,12 @@ struct SparkDecimalBinaryOperation
return calculateImpl<NativeResultType>(l, r, scale_left, scale_right, res, resultDataType, max_scale);
}

template <typename CalcType, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType>
template <typename CalcType, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType, typename ScaleDataType>
static NO_SANITIZE_UNDEFINED bool calculateImpl(
const LeftNativeType & l,
const RightNativeType & r,
const NativeResultType & scale_left,
const NativeResultType & scale_right,
const ScaleDataType & scale_left,
const ScaleDataType & scale_right,
NativeResultType & res,
const ResultDataType & resultDataType,
const size_t & max_scale)
Expand Down Expand Up @@ -410,13 +416,21 @@ struct SparkDecimalBinaryOperation
return elem[i].value;
}

template <typename NativeType, typename ResultNativeType>
static ResultNativeType applyScaled(const NativeType & l, const ResultNativeType & scale)
template <typename NativeType, typename ScaleType>
static ScaleType applyScaled(const NativeType & l, const ScaleType & scale)
{
if (scale > 1)
return common::mulIgnoreOverflow(l, scale);

return static_cast<ResultNativeType>(l);
return static_cast<ScaleType>(l);
}

/// Scale at which the binary decimal operation is carried out.
///
/// Multiplication accumulates both operand scales; every other operation
/// aligns the operands to the widest of the two input scales and the
/// result scale.
static size_t getMaxScaled(const size_t left_scale, const size_t right_scale, const size_t result_scale)
{
    if constexpr (is_multiply)
        return left_scale + right_scale;

    return std::max({left_scale, right_scale, result_scale});
}
};

Expand Down
7 changes: 7 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & acti
String function_name = "sparkCastFloatTo" + non_nullable_output_type->getName();
function_node = toFunctionNode(actions_dag, function_name, args);
}
else if ((isDecimal(non_nullable_input_type) && substrait_type.has_decimal()))
{
args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));

function_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
}
else
{
if (isString(non_nullable_input_type) && isInt(non_nullable_output_type))
Expand Down

0 comments on commit 6b1d63c

Please sign in to comment.