test make_ym_interval
marin-ma committed Mar 25, 2024
1 parent 201f322 commit 78bd925
Showing 15 changed files with 108 additions and 8 deletions.
@@ -45,7 +45,8 @@ class ValidatorApiImpl extends ValidatorApi {
   private def isPrimitiveType(dataType: DataType): Boolean = {
     dataType match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
-          StringType | BinaryType | _: DecimalType | DateType | TimestampType | NullType =>
+          StringType | BinaryType | _: DecimalType | DateType | TimestampType |
+          YearMonthIntervalType.DEFAULT | NullType =>
         true
       case _ => false
     }
@@ -27,7 +27,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}

 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, MakeYMInterval, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, Uuid}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -200,6 +200,7 @@ object BackendSettings extends BackendSettingsApi {
       case _: StructType => Some("StructType")
       case _: ArrayType => Some("ArrayType")
       case _: MapType => Some("MapType")
+      case _: YearMonthIntervalType => Some("YearMonthIntervalType")
       case _ => None
     }
   }
@@ -387,8 +388,7 @@ object BackendSettings extends BackendSettingsApi {
     expr match {
       // Block directly falling back the below functions by FallbackEmptySchemaRelation.
       case alias: Alias => checkExpr(alias.child)
-      case _: Rand => true
-      case _: Uuid => true
+      case _: Rand | _: Uuid | _: MakeYMInterval => true
       case _ => false
     }
   }
@@ -58,6 +58,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase
       case _: ArrayType =>
       case _: MapType =>
       case _: StructType =>
+      case YearMonthIntervalType.DEFAULT =>
       case _: NullType =>
       case _ =>
         throw new GlutenNotSupportException(
@@ -513,6 +513,38 @@ class VeloxFunctionsValidateSuite extends VeloxWholeStageTransformerSuite {
     }
   }

+  test("Test make_ym_interval function") {
+    runQueryAndCompare("select make_ym_interval(1, 1)") {
+      checkOperatorMatch[ProjectExecTransformer]
+    }
+
+    runQueryAndCompare("select make_ym_interval(1)") {
+      checkOperatorMatch[ProjectExecTransformer]
+    }
+
+    runQueryAndCompare("select make_ym_interval()") {
+      checkOperatorMatch[ProjectExecTransformer]
+    }
+
+    withTempPath {
+      path =>
+        Seq[Tuple2[Integer, Integer]]((1, 0), (-1, 1), (null, 1), (1, null))
+          .toDF("year", "month")
+          .write
+          .parquet(path.getCanonicalPath)
+
+        spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("make_ym_interval_tbl")
+
+        runQueryAndCompare("select make_ym_interval(year, month) from make_ym_interval_tbl") {
+          checkOperatorMatch[ProjectExecTransformer]
+        }
+
+        runQueryAndCompare("select make_ym_interval(year) from make_ym_interval_tbl") {
+          checkOperatorMatch[ProjectExecTransformer]
+        }
+    }
+  }
+
   test("Test uuid function") {
     runQueryAndCompare("""SELECT uuid() from lineitem limit 100""".stripMargin, false) {
       checkOperatorMatch[ProjectExecTransformer]
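For context (not part of the diff): make_ym_interval(years, months) builds a Spark YearMonthIntervalType value, with omitted arguments defaulting to 0, so all three literal forms in the suite above are valid. A hypothetical spark-shell check of the vanilla Spark behavior the suite compares against (result rendering assumed from Spark 3.x's ANSI interval format):

    // Assumed vanilla-Spark results, shown as comments; not code from this commit.
    spark.sql("select make_ym_interval(1, 1)").show(false) // INTERVAL '1-1' YEAR TO MONTH
    spark.sql("select make_ym_interval(1)").show(false)    // INTERVAL '1-0' YEAR TO MONTH
    spark.sql("select make_ym_interval()").show(false)     // INTERVAL '0-0' YEAR TO MONTH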
1 change: 1 addition & 0 deletions cpp/core/shuffle/Utils.cc
@@ -290,6 +290,7 @@ arrow::Result<std::vector<std::shared_ptr<arrow::DataType>>> gluten::toShuffleTypeId(
     case arrow::LargeListType::type_id:
     case arrow::Decimal128Type::type_id:
     case arrow::NullType::type_id:
+    case arrow::MonthIntervalType::type_id:
       shuffleTypeId.push_back(field->type());
       break;
     default:
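Background (not from the diff): a year-month interval is physically a single 32-bit count of total months in both Spark and Arrow, which is why reusing arrow::MonthIntervalType as the shuffle wire type needs no new serialization logic. A sketch of the encoding, with a hypothetical helper name:

    // toTotalMonths is illustrative only; it is not part of this patch.
    def toTotalMonths(years: Int, months: Int): Int = years * 12 + months
    assert(toTotalMonths(1, 1) == 13) // INTERVAL '1-1' YEAR TO MONTH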
3 changes: 3 additions & 0 deletions cpp/velox/substrait/SubstraitParser.cc
@@ -83,6 +83,9 @@ TypePtr SubstraitParser::parseType(const ::substrait::Type& substraitType, bool
       auto scale = substraitType.decimal().scale();
       return DECIMAL(precision, scale);
     }
+    case ::substrait::Type::KindCase::kIntervalYear: {
+      return INTERVAL_YEAR_MONTH();
+    }
     case ::substrait::Type::KindCase::kNothing:
       return UNKNOWN();
     default:
2 changes: 2 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxExpr.cc
@@ -136,6 +136,8 @@ TypePtr getScalarType(const ::substrait::Expression::Literal& literal) {
       return VARCHAR();
     case ::substrait::Expression_Literal::LiteralTypeCase::kBinary:
       return VARBINARY();
+    case ::substrait::Expression_Literal::LiteralTypeCase::kIntervalYearToMonth:
+      return INTERVAL_YEAR_MONTH();
     default:
       return nullptr;
   }
7 changes: 5 additions & 2 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -267,8 +267,8 @@ bool SubstraitToVeloxPlanValidator::validateCast(
   }

   const auto& toType = SubstraitParser::parseType(castExpr.type());
-  if (toType->kind() == TypeKind::TIMESTAMP) {
-    LOG_VALIDATION_MSG("Casting to TIMESTAMP is not supported.");
+  if (toType->kind() == TypeKind::TIMESTAMP || toType->isIntervalYearMonth()) {
+    LOG_VALIDATION_MSG("Casting to " + toType->toString() + " is not supported.");
     return false;
   }

@@ -284,6 +284,9 @@ bool SubstraitToVeloxPlanValidator::validateCast(
       LOG_VALIDATION_MSG("Casting from DATE to " + toType->toString() + " is not supported.");
       return false;
     }
+  } else if (input->type()->isIntervalYearMonth()) {
+    LOG_VALIDATION_MSG("Casting from INTERVAL_YEAR_MONTH is not supported.");
+    return false;
   }
   switch (input->type()->kind()) {
     case TypeKind::ARRAY:
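The planning effect (an illustration, not code from this commit): with these branches in place, a plan that casts to or from a year-month interval should fail validation and fall back to vanilla Spark, while a bare projection of the interval can still offload. A hedged sketch in the style of the suite above, assuming the suite exposes a noFallBack flag:

    // Hypothetical check: the cast is expected to prevent offload, so no
    // ProjectExecTransformer is asserted here.
    runQueryAndCompare("select cast(make_ym_interval(1, 1) as string)", noFallBack = false) { _ => }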
new file: IntervalYearTypeNode.java (package io.glutenproject.substrait.type)
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.substrait.type;
+
+import io.substrait.proto.Type;
+
+import java.io.Serializable;
+
+public class IntervalYearTypeNode implements TypeNode, Serializable {
+
+  private final Boolean nullable;
+
+  public IntervalYearTypeNode(Boolean nullable) {
+    this.nullable = nullable;
+  }
+
+  @Override
+  public Type toProtobuf() {
+    Type.IntervalYear.Builder intervalYearBuilder = Type.IntervalYear.newBuilder();
+    if (nullable) {
+      intervalYearBuilder.setNullability(Type.Nullability.NULLABILITY_NULLABLE);
+    } else {
+      intervalYearBuilder.setNullability(Type.Nullability.NULLABILITY_REQUIRED);
+    }
+    Type.Builder builder = Type.newBuilder();
+    builder.setIntervalYear(intervalYearBuilder.build());
+    return builder.build();
+  }
+}
@@ -69,6 +69,10 @@ public static TypeNode makeDate(Boolean nullable) {
     return new DateTypeNode(nullable);
   }

+  public static TypeNode makeIntervalYear(Boolean nullable) {
+    return new IntervalYearTypeNode(nullable);
+  }
+
   public static TypeNode makeDecimal(Boolean nullable, Integer precision, Integer scale) {
     return new DecimalTypeNode(nullable, precision, scale);
   }
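A minimal usage sketch of the new factory (hypothetical REPL use, not in the patch); the ConverterUtils hunk below routes YearMonthIntervalType.DEFAULT through it:

    import io.glutenproject.substrait.`type`.TypeBuilder

    // Build a nullable interval-year node and serialize it to Substrait
    // protobuf; the message carries only the nullability flag.
    val node = TypeBuilder.makeIntervalYear(true)
    val proto = node.toProtobuf() // interval_year { nullability: NULLABILITY_NULLABLE }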
@@ -235,6 +235,8 @@ object ConverterUtils extends Logging {
       TypeBuilder.makeBinary(nullable)
     case DateType =>
       TypeBuilder.makeDate(nullable)
+    case YearMonthIntervalType.DEFAULT =>
+      TypeBuilder.makeIntervalYear(nullable)
     case DecimalType() =>
       val decimalType = datatype.asInstanceOf[DecimalType]
       val precision = decimalType.precision
@@ -179,6 +179,7 @@ object ExpressionMappings {
     Sig[MonthsBetween](MONTHS_BETWEEN),
     Sig[DateFromUnixDate](DATE_FROM_UNIX_DATE),
     Sig[MakeTimestamp](MAKE_TIMESTAMP),
+    Sig[MakeYMInterval](MAKE_YM_INTERVAL),
     // JSON functions
     Sig[GetJsonObject](GET_JSON_OBJECT),
     Sig[LengthOfJsonArray](JSON_ARRAY_LENGTH),
@@ -76,7 +76,9 @@ trait GlutenPlan extends SparkPlan with LogLevelUtil {
         }
         // FIXME: Use a validation-specific method to catch validation failures
         TestStats.addFallBackClassName(this.getClass.toString)
-        logValidationMessage(s"Validation failed with exception for plan: $nodeName, due to:", e)
+        logValidationMessage(
+          s"Validation failed with exception for plan: $nodeName, due to: ${e.getMessage}",
+          e)
         ValidationResult.notOk(e.getMessage)
       } finally {
         TransformerState.finishValidation
@@ -19,7 +19,7 @@ package org.apache.spark.sql.utils
 import org.apache.spark.sql.types._

 import org.apache.arrow.vector.complex.MapVector
-import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

 import scala.collection.JavaConverters._
@@ -47,6 +47,8 @@ object SparkArrowUtil {
       } else {
         new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")
       }
+    case YearMonthIntervalType.DEFAULT =>
+      new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
     case _: ArrayType => ArrowType.List.INSTANCE
     case NullType => ArrowType.Null.INSTANCE
     case _ =>
@@ -69,6 +71,8 @@ object SparkArrowUtil {
     case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
     // TODO: Time unit is not handled.
     case _: ArrowType.Timestamp => TimestampType
+    case interval: ArrowType.Interval if interval.getUnit == IntervalUnit.YEAR_MONTH =>
+      YearMonthIntervalType.DEFAULT
     case ArrowType.Null.INSTANCE => NullType
     case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
   }
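Taken together (a summary, not code from the patch), the two hunks make the Spark-to-Arrow mapping a round trip for the default year-month interval:

    import org.apache.arrow.vector.types.IntervalUnit
    import org.apache.arrow.vector.types.pojo.ArrowType
    import org.apache.spark.sql.types.YearMonthIntervalType

    // The Spark -> Arrow direction now emits this Arrow type...
    val arrowTy = new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
    // ...and the Arrow -> Spark direction pattern-matches on YEAR_MONTH and
    // returns YearMonthIntervalType.DEFAULT, i.e. INTERVAL YEAR TO MONTH.
    val sparkTy = YearMonthIntervalType.DEFAULT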
@@ -191,6 +191,7 @@ object ExpressionNames {
   final val MONTHS_BETWEEN = "months_between"
   final val DATE_FROM_UNIX_DATE = "date_from_unix_date"
   final val MAKE_TIMESTAMP = "make_timestamp"
+  final val MAKE_YM_INTERVAL = "make_ym_interval"

   // JSON functions
   final val GET_JSON_OBJECT = "get_json_object"
