fix spark32
loneylee committed Mar 19, 2024
1 parent 21adec7 commit e0edb7f
Showing 9 changed files with 86 additions and 57 deletions.
@@ -372,38 +372,5 @@ class GlutenClickHouseTPCHParquetAQESuite
)(df => {})
}
}

test("test timestampadd") {
val sql =
"""
|select
| timestampadd(MICROSECOND, 2, o_orderdate),
| timestampadd(MILLISECOND, 2, o_orderdate),
| timestampadd(SECOND, 2, o_orderdate),
| timestampadd(MINUTE, 2, o_orderdate),
| timestampadd(HOUR, 2, o_orderdate),
| timestampadd(DAY, 2, o_orderdate),
| timestampadd(WEEK, 2, o_orderdate),
| timestampadd(MONTH, 2, o_orderdate),
| timestampadd(QUARTER, 2, o_orderdate),
| timestampadd(YEAR, 2, o_orderdate),
| timestampadd(DAYOFYEAR, 2, o_orderdate),
| dateadd(MICROSECOND, 2, o_orderdate),
| dateadd(MILLISECOND, 2, o_orderdate),
| dateadd(SECOND, 2, o_orderdate),
| dateadd(MINUTE, 2, o_orderdate),
| dateadd(HOUR, 2, o_orderdate),
| dateadd(DAY, 2, o_orderdate),
| dateadd(WEEK, 2, o_orderdate),
| dateadd(MONTH, 2, o_orderdate),
| dateadd(QUARTER, 2, o_orderdate),
| dateadd(YEAR, 2, o_orderdate),
| dateadd(DAYOFYEAR, 2, o_orderdate)
|
|from orders
|order by o_orderdate limit 10
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
}
// scalastyle:off line.size.limit
@@ -61,7 +61,7 @@ class FunctionParserTimestampAdd : public FunctionParser
"Unsupported timezone_field argument, should be a string literal, but: {}",
timezone_field.DebugString());

const auto & unit = unit_field.value().literal().string();
const auto & unit = Poco::toUpper(unit_field.value().literal().string());
auto timezone = timezone_field.value().literal().string();

std::string ch_function_name;
@@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.exception.GlutenNotSupportException
import io.glutenproject.execution.{ColumnarToRowExecBase, WholeStageTransformer}
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.test.TestStats
import io.glutenproject.utils.{DecimalArithmeticUtil, PlanUtil}

@@ -523,15 +524,20 @@ object ExpressionConverter extends SQLConfHelper with Logging {
          replaceWithExpressionTransformerInternal(n.right, attributeSeq, expressionsMap),
          n
        )
      case timestampAdd: TimestampAdd =>
      case timestampAdd if timestampAdd.getClass.getSimpleName.equals("TimestampAdd") =>
        val extract = SparkShimLoader.getSparkShims.extractExpressionTimestampAddUnit(timestampAdd)
        if (extract.isEmpty) {
          throw new UnsupportedOperationException(s"Not support expression TimestampAdd.")
        }
        val add = timestampAdd.asInstanceOf[BinaryExpression]
        TimestampAddTransform(
          substraitExprName,
          replaceWithExpressionTransformerInternal(timestampAdd.left, attributeSeq, expressionsMap),
          replaceWithExpressionTransformerInternal(
            timestampAdd.right,
            attributeSeq,
            expressionsMap),
          timestampAdd
          extract.get.head,
          replaceWithExpressionTransformerInternal(add.left, attributeSeq, expressionsMap),
          replaceWithExpressionTransformerInternal(add.right, attributeSeq, expressionsMap),
          extract.get.last,
          add.dataType,
          add.nullable
        )
      case e: Transformable =>
        val childrenTransformers =
@@ -173,7 +173,6 @@ object ExpressionMappings {
    Sig[TruncDate](TRUNC),
    Sig[TruncTimestamp](DATE_TRUNC),
    Sig[GetTimestamp](GET_TIMESTAMP),
    Sig[TimestampAdd](TIMESTAMP_ADD),
    Sig[NextDay](NEXT_DAY),
    Sig[LastDay](LAST_DAY),
    Sig[MonthsBetween](MONTHS_BETWEEN),
@@ -19,15 +19,18 @@ package io.glutenproject.expression
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DataType

import com.google.common.collect.Lists

case class TimestampAddTransform(
    substraitExprName: String,
    unit: String,
    left: ExpressionTransformer,
    right: ExpressionTransformer,
    original: TimestampAdd)
    timeZoneId: String,
    dataType: DataType,
    nullable: Boolean)
  extends ExpressionTransformer {

  override def doTransform(args: java.lang.Object): ExpressionNode = {
@@ -36,18 +39,15 @@ case class TimestampAddTransform(
    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
    val functionId = ExpressionBuilder.newScalarFunction(
      functionMap,
      ConverterUtils.makeFuncName(
        substraitExprName,
        original.children.map(_.dataType),
        FunctionConfig.REQ)
      ConverterUtils.makeFuncName(substraitExprName, Seq(), FunctionConfig.REQ)
    )

    val expressionNodes = Lists.newArrayList(
      ExpressionBuilder.makeStringLiteral(original.unit),
      ExpressionBuilder.makeStringLiteral(unit),
      leftNode,
      rightNode,
      ExpressionBuilder.makeStringLiteral(original.timeZoneId.getOrElse("")))
    val outputType = ConverterUtils.getTypeNode(original.dataType, original.nullable)
      ExpressionBuilder.makeStringLiteral(timeZoneId))
    val outputType = ConverterUtils.getTypeNode(dataType, nullable)
    ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, outputType)
  }
}
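Because the transformer no longer holds the Spark expression itself, everything it needs now arrives as an explicit field. A hypothetical helper sketch showing how those fields line up with what ExpressionConverter now passes in (the helper name and the Substrait function name are illustrative, not part of the commit):

import org.apache.spark.sql.types.DataType

// Hypothetical sketch: extracted is the Seq(unit, timeZoneId) returned by the shim,
// children are the two already-converted child transformers of the BinaryExpression.
def buildTimestampAddTransform(
    extracted: Seq[String],
    children: Seq[ExpressionTransformer],
    dataType: DataType,
    nullable: Boolean): TimestampAddTransform =
  TimestampAddTransform(
    substraitExprName = "timestamp_add", // assumed name, for illustration only
    unit = extracted.head,
    left = children.head,
    right = children(1),
    timeZoneId = extracted.last,
    dataType = dataType,
    nullable = nullable)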
@@ -137,6 +137,7 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("to_timestamp exception mode")
// Replaced by a gluten test to pass timezone through config.
.exclude("from_unixtime")
.exclude("test timestamp add")
enableSuite[GlutenDecimalExpressionSuite]
enableSuite[GlutenHashExpressionsSuite]
enableSuite[GlutenIntervalExpressionsSuite]
@@ -21,16 +21,17 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{TimeZoneUTC, getZoneId}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.{LocalDateTime, ZoneId}
import java.util.{Calendar, Locale, TimeZone}

import java.util.concurrent.TimeUnit._
import java.util.{Calendar, Locale, TimeZone}

class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTrait {
  override def testIntegralInput(testFunc: Number => Unit): Unit = {
@@ -58,6 +59,26 @@ class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTrait
    // checkResult(Int.MinValue.toLong - 100)
  }

  private def timestampLiteral(s: String, sdf: SimpleDateFormat, dt: DataType): Literal = {
    dt match {
      case _: TimestampType =>
        Literal(new Timestamp(sdf.parse(s).getTime))

      case _: TimestampNTZType =>
        Literal(LocalDateTime.parse(s.replace(" ", "T")))
    }
  }

  private def timestampAnswer(s: String, sdf: SimpleDateFormat, dt: DataType): Any = {
    dt match {
      case _: TimestampType =>
        DateTimeUtils.fromJavaTimestamp(new Timestamp(sdf.parse(s).getTime))

      case _: TimestampNTZType =>
        LocalDateTime.parse(s.replace(" ", "T"))
    }
  }
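  // Hedged usage sketch for the two helpers above (they appear to mirror the vanilla
  // DateExpressionsSuite utilities): the same wall-clock string becomes either a
  // TimestampType literal backed by java.sql.Timestamp or a TimestampNTZType literal
  // backed by LocalDateTime, plus a matching expected value for checkEvaluation, e.g.
  //   val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
  //   timestampLiteral("2022-02-15 12:57:00", sdf, TimestampType)   // Literal(Timestamp)
  //   timestampAnswer("2022-02-15 12:57:00", sdf, TimestampNTZType) // LocalDateTime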

testGluten("TIMESTAMP_MICROS") {
def testIntegralFunc(value: Number): Unit = {
checkEvaluation(MicrosToTimestamp(Literal(value)), value.longValue())
@@ -472,4 +493,26 @@ class GlutenDateExpressionsSuite extends DateExpressionsSuite with GlutenTestsTrait
      }
    }
  }

  test("test timestamp add") {
    // Check case-insensitivity
    checkEvaluation(
      TimestampAdd("SECOND", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2022-02-15 12:57:01"))
    checkEvaluation(
      TimestampAdd("MINUTE", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2022-02-15 12:58:00"))
    checkEvaluation(
      TimestampAdd("HOUR", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2022-02-15 13:57:00"))
    checkEvaluation(
      TimestampAdd("DAY", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2022-02-16 12:57:00"))
    checkEvaluation(
      TimestampAdd("MONTH", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2022-03-15 12:57:00"))
    checkEvaluation(
      TimestampAdd("YEAR", Literal(1), Literal(Timestamp.valueOf("2022-02-15 12:57:00"))),
      Timestamp.valueOf("2023-02-15 12:57:00"))
  }
}
@@ -177,4 +177,7 @@ trait SparkShims {
      keyGroupedPartitioning: Option[Seq[Expression]],
      filteredPartitions: Seq[Seq[InputPartition]],
      outputPartitioning: Partitioning): Seq[InputPartition] = filteredPartitions.flatten

  def extractExpressionTimestampAddUnit(timestampAdd: Expression): Option[Seq[String]] =
    Option.empty
}
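Since the trait supplies an Option.empty default, a pre-3.3 shim never has to name the TimestampAdd class at all. A minimal sketch of that assumption, using only calls already shown in this diff:

import io.glutenproject.sql.shims.SparkShimLoader
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

// Sketch: on Spark 3.2, where org.apache.spark.sql.catalyst.expressions.TimestampAdd
// does not exist, the loaded shim keeps this default. The converter then sees None and
// throws UnsupportedOperationException instead of referencing the missing class.
val shims: SparkShims = SparkShimLoader.getSparkShims
val notATimestampAdd: Expression = Literal(1)
assert(shims.extractExpressionTimestampAddUnit(notATimestampAdd).isEmpty)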
@@ -19,11 +19,11 @@ package io.glutenproject.sql.shims.spark33
import io.glutenproject.GlutenConfig
import io.glutenproject.execution.datasource.GlutenParquetWriterInjects
import io.glutenproject.expression.{ExpressionNames, Sig}
import io.glutenproject.expression.ExpressionNames.TIMESTAMP_ADD
import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}

import org.apache.spark.{ShuffleDependency, ShuffleUtils, SparkContext, SparkEnv, SparkException, TaskContext, TaskContextUtils}
import org.apache.spark._
import org.apache.spark.scheduler.TaskInfo
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.TimestampFormatter
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FileFormat, FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
@@ -71,7 +71,9 @@ class Spark33Shims extends SparkShims {
      Sig[SplitPart](ExpressionNames.SPLIT_PART),
      Sig[Sec](ExpressionNames.SEC),
      Sig[Csc](ExpressionNames.CSC),
      Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
      Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
      Sig[TimestampAdd](TIMESTAMP_ADD)
    )
  }

override def convertPartitionTransforms(
@@ -262,4 +264,12 @@ class Spark33Shims extends SparkShims {
  }

  override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
    null

  override def extractExpressionTimestampAddUnit(exp: Expression): Option[Seq[String]] = {
    exp match {
      case timestampAdd: TimestampAdd =>
        Option.apply(Seq(timestampAdd.unit, timestampAdd.timeZoneId.getOrElse("")))
      case _ => Option.empty
    }
  }
}
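For illustration, a hedged example of what this new override returns on Spark 3.3, with the constructor usage copied from the test above (the empty string is the timeZoneId fallback):

import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.{Literal, TimestampAdd}

// The shim packs (unit, timeZoneId) into a Seq; ExpressionConverter reads the unit
// from head and the timezone from last.
val expr = TimestampAdd("HOUR", Literal(2), Literal(Timestamp.valueOf("2022-02-15 12:57:00")))
val extracted = new Spark33Shims().extractExpressionTimestampAddUnit(expr)
// extracted == Some(Seq("HOUR", ""))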
