From e07f24c88293b4882bb8fd688acee9809478e07c Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Tue, 11 Jun 2024 09:28:02 -0400 Subject: [PATCH] feat: Support Ansi mode in abs function (#500) * change proto msg * QueryPlanSerde with eval mode * Move eval mode * Add abs in planner * CometAbsFunc wrapper * Add error management * Add tests * Add license * spotless apply * format * Fix clippy * error msg for all spark versions * Fix benches * Use enum to ansi mode * Fix format * Add more tests * Format * Refactor * refactor * fix merge * fix merge --- core/benches/cast_from_string.rs | 2 +- core/benches/cast_numeric.rs | 2 +- .../execution/datafusion/expressions/abs.rs | 87 +++++++++++++++++++ .../execution/datafusion/expressions/cast.rs | 9 +- .../execution/datafusion/expressions/mod.rs | 29 +++++++ .../datafusion/expressions/negative.rs | 8 +- core/src/execution/datafusion/planner.rs | 19 ++-- core/src/execution/proto/expr.proto | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 14 +-- .../apache/comet/CometExpressionSuite.scala | 54 ++++++++++++ 10 files changed, 195 insertions(+), 30 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/abs.rs diff --git a/core/benches/cast_from_string.rs b/core/benches/cast_from_string.rs index 5bfaebf34..9a9ab18cc 100644 --- a/core/benches/cast_from_string.rs +++ b/core/benches/cast_from_string.rs @@ -17,7 +17,7 @@ use arrow_array::{builder::StringBuilder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; diff --git a/core/benches/cast_numeric.rs b/core/benches/cast_numeric.rs index 398be6946..35f24ce53 100644 --- a/core/benches/cast_numeric.rs +++ b/core/benches/cast_numeric.rs @@ -17,7 +17,7 @@ use arrow_array::{builder::Int32Builder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs new file mode 100644 index 000000000..4eb8c7c1e --- /dev/null +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use arrow_schema::ArrowError; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; +use datafusion_common::DataFusionError; +use datafusion_functions::math; +use std::{any::Any, sync::Arc}; + +use crate::execution::operators::ExecutionError; + +use super::{arithmetic_overflow_error, EvalMode}; + +#[derive(Debug)] +pub struct CometAbsFunc { + inner_abs_func: Arc, + eval_mode: EvalMode, + data_type_name: String, +} + +impl CometAbsFunc { + pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result { + if let EvalMode::Legacy | EvalMode::Ansi = eval_mode { + Ok(Self { + inner_abs_func: math::abs().inner(), + eval_mode, + data_type_name, + }) + } else { + Err(ExecutionError::GeneralError(format!( + "Invalid EvalMode: \"{:?}\"", + eval_mode + ))) + } + } +} + +impl ScalarUDFImpl for CometAbsFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "abs" + } + + fn signature(&self) -> &Signature { + self.inner_abs_func.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner_abs_func.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match self.inner_abs_func.invoke(args) { + Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) + if msg.contains("overflow") => + { + if self.eval_mode == EvalMode::Legacy { + Ok(args[0].clone()) + } else { + let msg = arithmetic_overflow_error(&self.data_type_name).to_string(); + Err(DataFusionError::ArrowError( + ArrowError::ComputeError(msg), + trace, + )) + } + } + other => other, + } + } +} diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 045626465..4dae62dce 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -52,6 +52,8 @@ use crate::{ }, }; +use super::EvalMode; + static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); static CAST_OPTIONS: CastOptions = CastOptions { @@ -61,13 +63,6 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; -#[derive(Debug, Hash, PartialEq, Clone, Copy)] -pub enum EvalMode { - Legacy, - Ansi, - Try, -} - #[derive(Debug, Hash)] pub struct Cast { pub child: Arc, diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 05230b4c2..5d5f58e0c 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -24,6 +24,10 @@ pub mod if_expr; mod normalize_nan; pub mod scalar_funcs; pub use normalize_nan::NormalizeNaNAndZero; +use prost::DecodeError; + +use crate::{errors::CometError, execution::spark_expression}; +pub mod abs; pub mod avg; pub mod avg_decimal; pub mod bloom_filter_might_contain; @@ -39,3 +43,28 @@ pub mod temporal; pub mod unbound; mod utils; pub mod variance; + +#[derive(Debug, Hash, PartialEq, Clone, Copy)] +pub enum EvalMode { + Legacy, + Ansi, + Try, +} + +impl TryFrom for EvalMode { + type Error = DecodeError; + + fn try_from(value: i32) -> Result { + match spark_expression::EvalMode::try_from(value)? { + spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy), + spark_expression::EvalMode::Try => Ok(EvalMode::Try), + spark_expression::EvalMode::Ansi => Ok(EvalMode::Ansi), + } + } +} + +fn arithmetic_overflow_error(from_type: &str) -> CometError { + CometError::ArithmeticOverflow { + from_type: from_type.to_string(), + } +} diff --git a/core/src/execution/datafusion/expressions/negative.rs b/core/src/execution/datafusion/expressions/negative.rs index a85cde89e..cd0e9bccf 100644 --- a/core/src/execution/datafusion/expressions/negative.rs +++ b/core/src/execution/datafusion/expressions/negative.rs @@ -33,6 +33,8 @@ use std::{ sync::Arc, }; +use super::arithmetic_overflow_error; + pub fn create_negate_expr( expr: Arc, fail_on_error: bool, @@ -48,12 +50,6 @@ pub struct NegativeExpr { fail_on_error: bool, } -fn arithmetic_overflow_error(from_type: &str) -> CometError { - CometError::ArithmeticOverflow { - from_type: from_type.to_string(), - } -} - macro_rules! check_overflow { ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{ let typed_array = $array diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index e51932154..d92bf5789 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -24,7 +24,6 @@ use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, execution::FunctionRegistry, - functions::math, logical_expr::Operator as DataFusionOperator, physical_expr::{ execution_props::ExecutionProps, @@ -51,6 +50,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, }; +use datafusion_expr::ScalarUDF; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use itertools::Itertools; use jni::objects::GlobalRef; @@ -65,7 +65,7 @@ use crate::{ avg_decimal::AvgDecimal, bitwise_not::BitwiseNotExpr, bloom_filter_might_contain::BloomFilterMightContain, - cast::{Cast, EvalMode}, + cast::Cast, checkoverflow::CheckOverflow, correlation::Correlation, covariance::Covariance, @@ -97,6 +97,8 @@ use crate::{ }, }; +use super::expressions::{abs::CometAbsFunc, EvalMode}; + // For clippy error on type_complexity. type ExecResult = Result; type PhyAggResult = Result>, ExecutionError>; @@ -356,11 +358,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { - spark_expression::EvalMode::Legacy => EvalMode::Legacy, - spark_expression::EvalMode::Try => EvalMode::Try, - spark_expression::EvalMode::Ansi => EvalMode::Ansi, - }; + let eval_mode = expr.eval_mode.try_into()?; Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } @@ -499,7 +497,12 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let return_type = child.data_type(&input_schema)?; let args = vec![child]; - let expr = ScalarFunctionExpr::new("abs", math::abs(), args, return_type); + let eval_mode = expr.eval_mode.try_into()?; + let comet_abs = Arc::new(ScalarUDF::new_from_impl(CometAbsFunc::new( + eval_mode, + return_type.to_string(), + )?)); + let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type); Ok(Arc::new(expr)) } ExprStruct::CaseWhen(case_when) => { diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 5192bbd4c..093b07b3c 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -480,6 +480,7 @@ message BitwiseNot { message Abs { Expr child = 1; + EvalMode eval_mode = 2; } message Subquery { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5a0ad38d7..c1c8b5c56 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1476,15 +1476,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } - case Abs(child, _) => + case Abs(child, failOnErr) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val abs = - ExprOuterClass.Abs - .newBuilder() - .setChild(childExpr.get) - .build() - Some(Expr.newBuilder().setAbs(abs).build()) + val evalModeStr = + if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY + val absBuilder = ExprOuterClass.Abs.newBuilder() + absBuilder.setChild(childExpr.get) + absBuilder.setEvalMode(evalModeStr) + Some(Expr.newBuilder().setAbs(absBuilder).build()) } else { withInfo(expr, child) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 10fbc468f..a2b6edd0e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -19,6 +19,9 @@ package org.apache.comet +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -850,6 +853,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("abs Overflow ansi mode") { + + def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetTable(data, "tbl") { + checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { + case (Some(sparkExc), Some(cometExc)) => + val cometErrorPattern = + """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r + assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined) + assert(sparkExc.getMessage.contains("overflow")) + case _ => fail("Exception should be thrown") + } + } + } + + def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") + } + } + + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue))) + testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue))) + testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue))) + testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue))) + testAbsAnsi(Seq((Float.MaxValue, Float.MinValue))) + testAbsAnsi(Seq((Double.MaxValue, Double.MinValue))) + } + } + + test("abs Overflow legacy mode") { + + def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") + } + } + } + + testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue))) + testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue))) + testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue))) + testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue))) + testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue))) + testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue))) + } + test("ceil and floor") { Seq("true", "false").foreach { dictionary => withSQLConf(