From e0edb7f67ecc8876627b46cd690cd96b5de5ae80 Mon Sep 17 00:00:00 2001
From: loneylee
Date: Tue, 19 Mar 2024 16:13:56 +0800
Subject: [PATCH] fix spark32

TimestampAdd only exists since Spark 3.3, so gluten-core must not reference
the class directly. Match the expression by its class name instead, move the
Sig mapping into Spark33Shims, and add a shim hook that extracts
(unit, timeZoneId) on Spark versions that have the expression. The CH parser
now upper-cases the unit so unit matching is case-insensitive. The TPCH
timestampadd test is replaced by a unit test in GlutenDateExpressionsSuite
(excluded for Velox).
---
 .../GlutenClickHouseTPCHParquetAQESuite.scala | 33 -------------
 .../scalar_function_parser/timestampAdd.cpp   |  2 +-
 .../expression/ExpressionConverter.scala      | 20 +++++---
 .../expression/ExpressionMappings.scala       |  1 -
 .../expression/TimestampAddTransform.scala    | 18 +++----
 .../utils/velox/VeloxTestSettings.scala       |  1 +
 .../GlutenDateExpressionsSuite.scala          | 47 ++++++++++++++++++-
 .../glutenproject/sql/shims/SparkShims.scala  |  3 ++
 .../sql/shims/spark33/Spark33Shims.scala      | 18 +++++--
 9 files changed, 86 insertions(+), 57 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala
index f940198115742..f7a74f9b16b86 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala
@@ -372,38 +372,5 @@ class GlutenClickHouseTPCHParquetAQESuite
       )(df => {})
     }
   }
-
-  test("test timestampadd") {
-    val sql =
-      """
-        |select
-        |  timestampadd(MICROSECOND, 2, o_orderdate),
-        |  timestampadd(MILLISECOND, 2, o_orderdate),
-        |  timestampadd(SECOND, 2, o_orderdate),
-        |  timestampadd(MINUTE, 2, o_orderdate),
-        |  timestampadd(HOUR, 2, o_orderdate),
-        |  timestampadd(DAY, 2, o_orderdate),
-        |  timestampadd(WEEK, 2, o_orderdate),
-        |  timestampadd(MONTH, 2, o_orderdate),
-        |  timestampadd(QUARTER, 2, o_orderdate),
-        |  timestampadd(YEAR, 2, o_orderdate),
-        |  timestampadd(DAYOFYEAR, 2, o_orderdate),
-        |  dateadd(MICROSECOND, 2, o_orderdate),
-        |  dateadd(MILLISECOND, 2, o_orderdate),
-        |  dateadd(SECOND, 2, o_orderdate),
-        |  dateadd(MINUTE, 2, o_orderdate),
-        |  dateadd(HOUR, 2, o_orderdate),
-        |  dateadd(DAY, 2, o_orderdate),
-        |  dateadd(WEEK, 2, o_orderdate),
-        |  dateadd(MONTH, 2, o_orderdate),
-        |  dateadd(QUARTER, 2, o_orderdate),
-        |  dateadd(YEAR, 2, o_orderdate),
-        |  dateadd(DAYOFYEAR, 2, o_orderdate)
-        |
-        |from orders
-        |order by o_orderdate limit 10
-        |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
-  }
 }
 // scalastyle:off line.size.limit
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
index 21a0989d9a5cd..d76431c0a0964 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
@@ -61,7 +61,7 @@ class FunctionParserTimestampAdd : public FunctionParser
                 "Unsupported timezone_field argument, should be a string literal, but: {}",
                 timezone_field.DebugString());
 
-        const auto & unit = unit_field.value().literal().string();
+        const auto & unit = Poco::toUpper(unit_field.value().literal().string());
         auto timezone = timezone_field.value().literal().string();
 
         std::string ch_function_name;
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
index dfe4e552cbb93..0bda58b7df51d 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala
@@ -20,6 +20,7 @@
 import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.exception.GlutenNotSupportException
 import io.glutenproject.execution.{ColumnarToRowExecBase, WholeStageTransformer}
+import io.glutenproject.sql.shims.SparkShimLoader
 import io.glutenproject.test.TestStats
 import io.glutenproject.utils.{DecimalArithmeticUtil, PlanUtil}
@@ -523,15 +524,20 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           replaceWithExpressionTransformerInternal(n.right, attributeSeq, expressionsMap),
           n
         )
-      case timestampAdd: TimestampAdd =>
+      case timestampAdd if timestampAdd.getClass.getSimpleName.equals("TimestampAdd") =>
+        val extract = SparkShimLoader.getSparkShims.extractExpressionTimestampAddUnit(timestampAdd)
+        if (extract.isEmpty) {
+          throw new UnsupportedOperationException("Unsupported expression TimestampAdd.")
+        }
+        val add = timestampAdd.asInstanceOf[BinaryExpression]
         TimestampAddTransform(
           substraitExprName,
-          replaceWithExpressionTransformerInternal(timestampAdd.left, attributeSeq, expressionsMap),
-          replaceWithExpressionTransformerInternal(
-            timestampAdd.right,
-            attributeSeq,
-            expressionsMap),
-          timestampAdd
+          extract.get.head,
+          replaceWithExpressionTransformerInternal(add.left, attributeSeq, expressionsMap),
+          replaceWithExpressionTransformerInternal(add.right, attributeSeq, expressionsMap),
+          extract.get.last,
+          add.dataType,
+          add.nullable
         )
       case e: Transformable =>
         val childrenTransformers =
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index fc05a873650af..c0ff0e8707ebe 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -173,7 +173,6 @@ object ExpressionMappings {
     Sig[TruncDate](TRUNC),
     Sig[TruncTimestamp](DATE_TRUNC),
     Sig[GetTimestamp](GET_TIMESTAMP),
-    Sig[TimestampAdd](TIMESTAMP_ADD),
     Sig[NextDay](NEXT_DAY),
     Sig[LastDay](LAST_DAY),
     Sig[MonthsBetween](MONTHS_BETWEEN),
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/TimestampAddTransform.scala b/gluten-core/src/main/scala/io/glutenproject/expression/TimestampAddTransform.scala
index c5ada1d6d276f..cc4d7bb472cf1 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/TimestampAddTransform.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/TimestampAddTransform.scala
@@ -19,15 +19,18 @@ package io.glutenproject.expression
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
 import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
 
 import com.google.common.collect.Lists
 
 case class TimestampAddTransform(
     substraitExprName: String,
+    unit: String,
     left: ExpressionTransformer,
     right: ExpressionTransformer,
-    original: TimestampAdd)
+    timeZoneId: String,
+    dataType: DataType,
+    nullable: Boolean)
   extends ExpressionTransformer {
 
   override def doTransform(args: java.lang.Object): ExpressionNode = {
@@ -36,18 +39,15 @@ case class TimestampAddTransform(
     val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
     val functionId = ExpressionBuilder.newScalarFunction(
       functionMap,
-      ConverterUtils.makeFuncName(
-        substraitExprName,
-        original.children.map(_.dataType),
-        FunctionConfig.REQ)
+      ConverterUtils.makeFuncName(substraitExprName, Seq(), FunctionConfig.REQ)
     )
 
     val expressionNodes = Lists.newArrayList(
-      ExpressionBuilder.makeStringLiteral(original.unit),
+      ExpressionBuilder.makeStringLiteral(unit),
       leftNode,
       rightNode,
-      ExpressionBuilder.makeStringLiteral(original.timeZoneId.getOrElse("")))
-    val outputType = ConverterUtils.getTypeNode(original.dataType, original.nullable)
+      ExpressionBuilder.makeStringLiteral(timeZoneId))
+    val outputType = ConverterUtils.getTypeNode(dataType, nullable)
     ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, outputType)
   }
 }
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index bc311716818d5..42be98ed4b629 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -137,6 +137,7 @@ class VeloxTestSettings extends BackendTestSettings {
       .exclude("to_timestamp exception mode")
       // Replaced by a gluten test to pass timezone through config.
       .exclude("from_unixtime")
+      .exclude("test timestamp add")
   enableSuite[GlutenDecimalExpressionSuite]
   enableSuite[GlutenHashExpressionsSuite]
   enableSuite[GlutenIntervalExpressionsSuite]
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDateExpressionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDateExpressionsSuite.scala
index e726dcea18c7f..813743d47f63a 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDateExpressionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDateExpressionsSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{TimeZoneUTC, getZoneId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -29,8 +29,9 @@ import org.apache.spark.unsafe.types.UTF8String
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 import java.time.{LocalDateTime, ZoneId}
-import java.util.{Calendar, Locale, TimeZone}
+
 import java.util.concurrent.TimeUnit._
+import java.util.{Calendar, Locale, TimeZone}
 
 class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTrait {
   override def testIntegralInput(testFunc: Number => Unit): Unit = {
@@ -58,6 +59,26 @@ class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTr
     // checkResult(Int.MinValue.toLong - 100)
   }
 
+  private def timestampLiteral(s: String, sdf: SimpleDateFormat, dt: DataType): Literal = {
+    dt match {
+      case _: TimestampType =>
+        Literal(new Timestamp(sdf.parse(s).getTime))
+
+      case _: TimestampNTZType =>
+        Literal(LocalDateTime.parse(s.replace(" ", "T")))
+    }
+  }
+
+  private def timestampAnswer(s: String, sdf: SimpleDateFormat, dt: DataType): Any = {
+    dt match {
+      case _: TimestampType =>
+        DateTimeUtils.fromJavaTimestamp(new Timestamp(sdf.parse(s).getTime))
+
+      case _: TimestampNTZType =>
+        LocalDateTime.parse(s.replace(" ", "T"))
+    }
+  }
+
testGluten("TIMESTAMP_MICROS") { def testIntegralFunc(value: Number): Unit = { checkEvaluation(MicrosToTimestamp(Literal(value)), value.longValue()) @@ -472,4 +493,26 @@ class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTr } } } + + test("test timestamp add") { + // Check case-insensitivity + checkEvaluation( + TimestampAdd("SECOND", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2022-02-15 12:57:01")) + checkEvaluation( + TimestampAdd("MINUTE", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2022-02-15 12:58:00")) + checkEvaluation( + TimestampAdd("HOUR", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2022-02-15 13:57:00")) + checkEvaluation( + TimestampAdd("DAY", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2022-02-16 12:57:00")) + checkEvaluation( + TimestampAdd("MONTH", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2022-03-15 12:57:00")) + checkEvaluation( + TimestampAdd("YEAR", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))), + Timestamp.valueOf("2023-02-15 12:57:00")) + } } diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala index 64ed1b866c0c1..0b390b97993cd 100644 --- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala @@ -177,4 +177,7 @@ trait SparkShims { keyGroupedPartitioning: Option[Seq[Expression]], filteredPartitions: Seq[Seq[InputPartition]], outputPartitioning: Partitioning): Seq[InputPartition] = filteredPartitions.flatten + + def extractExpressionTimestampAddUnit(timestampAdd: Expression): Option[Seq[String]] = + Option.empty } diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala index 8e770325f9b30..91372eb0c4442 100644 --- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala @@ -19,11 +19,11 @@ package io.glutenproject.sql.shims.spark33 import io.glutenproject.GlutenConfig import io.glutenproject.execution.datasource.GlutenParquetWriterInjects import io.glutenproject.expression.{ExpressionNames, Sig} +import io.glutenproject.expression.ExpressionNames.TIMESTAMP_ADD import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims} -import org.apache.spark.{ShuffleDependency, ShuffleUtils, SparkContext, SparkEnv, SparkException, TaskContext, TaskContextUtils} +import org.apache.spark._ import org.apache.spark.scheduler.TaskInfo -import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan} -import org.apache.spark.sql.execution.datasources.{BucketingUtils, FileFormat, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex} +import 
+import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
@@ -71,7 +71,9 @@ class Spark33Shims extends SparkShims {
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
       Sig[Sec](ExpressionNames.SEC),
       Sig[Csc](ExpressionNames.CSC),
-      Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
+      Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
+      Sig[TimestampAdd](TIMESTAMP_ADD)
+    )
   }
 
   override def convertPartitionTransforms(
@@ -262,4 +264,12 @@ class Spark33Shims extends SparkShims {
   }
   override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
     null
+
+  override def extractExpressionTimestampAddUnit(exp: Expression): Option[Seq[String]] = {
+    exp match {
+      case timestampAdd: TimestampAdd =>
+        Some(Seq(timestampAdd.unit, timestampAdd.timeZoneId.getOrElse("")))
+      case _ => None
+    }
+  }
 }
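
Reviewer notes (editorial, not part of the patch):

The core of this change is that gluten-core compiles against Spark 3.2, where
org.apache.spark.sql.catalyst.expressions.TimestampAdd does not exist (it was
added in Spark 3.3). The converter therefore matches the expression by its
simple class name and delegates field extraction to a version-specific shim.
Below is a minimal, self-contained sketch of that pattern; ShimDispatchSketch,
VersionShim, Spark32Shim and Spark33Shim are illustrative stand-ins, not the
real Gluten classes.

    object ShimDispatchSketch {
      // Stand-in for Spark 3.3's TimestampAdd expression.
      case class TimestampAdd(unit: String, timeZoneId: Option[String])

      trait VersionShim {
        // Returns Seq(unit, timeZoneId) for a TimestampAdd, None otherwise.
        // This mirrors SparkShims.extractExpressionTimestampAddUnit.
        def extractTimestampAddUnit(e: Any): Option[Seq[String]] = None
      }

      // Spark 3.2 has no TimestampAdd, so its shim keeps the None default.
      object Spark32Shim extends VersionShim

      object Spark33Shim extends VersionShim {
        override def extractTimestampAddUnit(e: Any): Option[Seq[String]] = e match {
          case t: TimestampAdd => Some(Seq(t.unit, t.timeZoneId.getOrElse("")))
          case _ => None
        }
      }

      // Common code never names the class; it matches on the simple class name
      // and fails only when the active shim cannot extract the fields.
      def convert(shim: VersionShim, e: Any): String =
        if (e.getClass.getSimpleName == "TimestampAdd") {
          shim.extractTimestampAddUnit(e) match {
            case Some(Seq(unit, tz)) => s"timestamp_add(unit=$unit, tz=$tz)"
            case _ =>
              throw new UnsupportedOperationException("Unsupported expression TimestampAdd.")
          }
        } else {
          "unchanged"
        }

      def main(args: Array[String]): Unit = {
        val expr = TimestampAdd("MINUTE", Some("UTC"))
        println(convert(Spark33Shim, expr)) // timestamp_add(unit=MINUTE, tz=UTC)
        // convert(Spark32Shim, expr) would throw UnsupportedOperationException,
        // which is the same guard ExpressionConverter now performs.
      }
    }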
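The CH-side change is one line: the parser now upper-cases the unit literal
via Poco::toUpper before matching it, so SQL such as timestampadd(second, ...)
and timestampadd(SECOND, ...) resolve identically. The sketch below shows the
equivalent normalization in Scala; the unit list is taken from the removed
TPCH test and is illustrative of what the CH parser accepts, not
authoritative.

    object UnitNormalizationSketch {
      private val supportedUnits = Set(
        "MICROSECOND", "MILLISECOND", "SECOND", "MINUTE", "HOUR",
        "DAY", "DAYOFYEAR", "WEEK", "MONTH", "QUARTER", "YEAR")

      // Upper-case first, then validate, so the unit comparison is
      // case-insensitive exactly once, at the boundary.
      def resolveUnit(rawUnit: String): Option[String] = {
        val unit = rawUnit.toUpperCase(java.util.Locale.ROOT)
        if (supportedUnits.contains(unit)) Some(unit) else None
      }

      def main(args: Array[String]): Unit = {
        println(resolveUnit("second"))    // Some(SECOND)
        println(resolveUnit("Quarter"))   // Some(QUARTER)
        println(resolveUnit("FORTNIGHT")) // None
      }
    }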