diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java index 2e2966e8614..d65721717ac 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java @@ -101,6 +101,11 @@ *
  • SQL semantic: expr1 IS NOT DISTINCT FROM expr2 *
  • Since version: 3.3.0 * + *
  • Name: STARTS_WITH + * * * * @since 3.0.0 diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index 2fd009b79cb..a5a8f812d22 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -299,6 +299,18 @@ ExpressionTransformResult visitLike(final Predicate like) { return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN); } + @Override + ExpressionTransformResult visitStartsWith(Predicate startsWith) { + List<ExpressionTransformResult> children = + startsWith.getChildren().stream().map(this::visit).collect(toList()); + Predicate transformedExpression = + StartsWithExpressionEvaluator.validateAndTransform( + startsWith, + children.stream().map(e -> e.expression).collect(toList()), + children.stream().map(e -> e.outputType).collect(toList())); + return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN); + } + private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { checkArgument( @@ -610,6 +622,12 @@ ColumnVector visitLike(final Predicate like) { children, children.stream().map(this::visit).collect(toList())); } + @Override + ColumnVector visitStartsWith(Predicate startsWith) { + return StartsWithExpressionEvaluator.eval( + startsWith.getChildren().stream().map(this::visit).collect(toList())); + } + /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index b59db8689ab..016d86bb70e 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -15,12 +15,14 @@ */ package io.delta.kernel.defaults.internal.expressions; +import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Literal; import io.delta.kernel.internal.util.Utils; import io.delta.kernel.types.*; import java.math.BigDecimal; @@ -383,4 +385,28 @@ private ColumnVector getVector(int rowId) { } }; } + + /** + * Checks the argument count of an expression. Throws {@code unsupportedExpressionException} if + * the argument count does not match the expected count. 
+ */ + static void checkArgsCount(Expression expr, int expectedCount, String exprName, String context) { + if (expr.getChildren().size() != expectedCount) { + throw unsupportedExpressionException( + expr, String.format("Invalid number of inputs of %s expression, %s", exprName, context)); + } + } + + static void checkIsStringType(DataType dataType, Expression parentExpr, String errorMessage) { + if (StringType.STRING.equals(dataType)) { + return; + } + throw unsupportedExpressionException(parentExpr, errorMessage); + } + + static void checkIsLiteral(Expression expr, Expression parentExpr, String errorMessage) { + if (!(expr instanceof Literal)) { + throw unsupportedExpressionException(parentExpr, errorMessage); + } + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index cf67d87dd23..d7d30413f9a 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -63,6 +63,8 @@ abstract class ExpressionVisitor<R> { abstract R visitLike(Predicate predicate); + abstract R visitStartsWith(Predicate predicate); + final R visit(Expression expression) { if (expression instanceof PartitionValueExpression) { return visitPartitionValue((PartitionValueExpression) expression); @@ -113,6 +115,8 @@ private R visitScalarExpression(ScalarExpression expression) { return visitTimeAdd(expression); case "LIKE": return visitLike(new Predicate(name, children)); + case "STARTS_WITH": + return visitStartsWith(new Predicate(name, children)); default: throw new UnsupportedOperationException( String.format("Scalar expression `%s` is not supported.", name)); diff --git 
a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java new file mode 100644 index 00000000000..aeb0ee4dd6d --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java @@ -0,0 +1,87 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.expressions; + +import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.internal.util.Utils; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.DataType; +import java.util.List; + +public class StartsWithExpressionEvaluator { + + /** Validates and transforms the {@code starts_with} expression. 
*/ + static Predicate validateAndTransform( + Predicate startsWith, + List<Expression> childrenExpressions, + List<DataType> childrenOutputTypes) { + checkArgsCount( + startsWith, + /* expectedCount= */ 2, + startsWith.getName(), + "Example usage: STARTS_WITH(column, 'test')"); + for (DataType dataType : childrenOutputTypes) { + checkIsStringType(dataType, startsWith, "'STARTS_WITH' expects STRING type inputs"); + } + + // TODO: support non literal as the second input of starts with. + checkIsLiteral( + childrenExpressions.get(1), + startsWith, + "'STARTS_WITH' expects literal as the second input"); + return new Predicate(startsWith.getName(), childrenExpressions); + } + + static ColumnVector eval(List<ColumnVector> childrenVectors) { + return new ColumnVector() { + final ColumnVector left = childrenVectors.get(0); + final ColumnVector right = childrenVectors.get(1); + + @Override + public DataType getDataType() { + return BooleanType.BOOLEAN; + } + + @Override + public int getSize() { + return left.getSize(); + } + + @Override + public void close() { + Utils.closeCloseables(left, right); + } + + @Override + public boolean getBoolean(int rowId) { + if (isNullAt(rowId)) { + // The return value is undefined and can be anything, if the slot for rowId is null. 
+ return false; + } + return left.getString(rowId).startsWith(right.getString(rowId)); + } + + @Override + public boolean isNullAt(int rowId) { + return left.isNullAt(rowId) || right.isNullAt(rowId); + } + }; + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index bb1cc6b6809..eedea5faf62 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -582,6 +582,91 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa "LIKE expression has invalid escape sequence")) } + test("evaluate expression: starts with") { + val col1 = stringVector(Seq[String]("one", "two", "t%hree", "four", null, null, "%")) + val col2 = stringVector(Seq[String]("o", "t", "T", "4", "f", null, null)) + val schema = new StructType() + .add("col1", StringType.STRING) + .add("col2", StringType.STRING) + val input = new DefaultColumnarBatch(col1.getSize, schema, Array(col1, col2)) + + val startsWithExpressionLiteral = startsWith(new Column("col1"), Literal.ofString("t%")) + val expOutputVectorLiteral = + booleanVector(Seq[BooleanJ](false, false, true, false, null, null, false)) + checkBooleanVectors(new DefaultExpressionEvaluator( + schema, startsWithExpressionLiteral, BooleanType.BOOLEAN).eval(input), expOutputVectorLiteral) + + val startsWithExpressionNullLiteral = startsWith(new Column("col1"), Literal.ofString(null)) + val allNullVector = + booleanVector(Seq[BooleanJ](null, null, null, null, null, null, null)) + checkBooleanVectors(new DefaultExpressionEvaluator( + schema, startsWithExpressionNullLiteral, BooleanType.BOOLEAN).eval(input), allNullVector) 
+ + // Two literal expressions on both sides + val startsWithExpressionAlwaysTrue = startsWith(Literal.ofString("ABC"), Literal.ofString("A")) + val allTrueVector = booleanVector(Seq[BooleanJ](true, true, true, true, true, true, true)) + checkBooleanVectors(new DefaultExpressionEvaluator( + schema, startsWithExpressionAlwaysTrue, BooleanType.BOOLEAN).eval(input), allTrueVector) + + val startsWithExpressionAlwaysFalse = + startsWith(Literal.ofString("ABC"), Literal.ofString("_B%")) + val allFalseVector = + booleanVector(Seq[BooleanJ](false, false, false, false, false, false, false)) + checkBooleanVectors(new DefaultExpressionEvaluator( + schema, startsWithExpressionAlwaysFalse, BooleanType.BOOLEAN).eval(input), allFalseVector) + + // scalastyle:off nonascii + val colUnicode = stringVector(Seq[String]("中文", "中", "文")) + val schemaUnicode = new StructType().add("col", StringType.STRING) + val inputUnicode = new DefaultColumnarBatch(colUnicode.getSize, + schemaUnicode, Array(colUnicode)) + val startsWithExpressionUnicode = startsWith(new Column("col"), Literal.ofString("中")) + val expOutputVectorLiteralUnicode = booleanVector(Seq[BooleanJ](true, true, false)) + checkBooleanVectors(new DefaultExpressionEvaluator(schemaUnicode, + startsWithExpressionUnicode, + BooleanType.BOOLEAN).eval(inputUnicode), expOutputVectorLiteralUnicode) + + // scalastyle:off nonascii + val colSurrogatePair = stringVector(Seq[String]("💕😉💕", "😉💕", "💕")) + val schemaSurrogatePair = new StructType().add("col", StringType.STRING) + val inputSurrogatePair = new DefaultColumnarBatch(colSurrogatePair.getSize, + schemaSurrogatePair, Array(colSurrogatePair)) + val startsWithExpressionSurrogatePair = startsWith(new Column("col"), Literal.ofString("💕")) + val expOutputVectorLiteralSurrogatePair = booleanVector(Seq[BooleanJ](true, false, true)) + checkBooleanVectors(new DefaultExpressionEvaluator(schemaSurrogatePair, + startsWithExpressionSurrogatePair, + BooleanType.BOOLEAN).eval(inputSurrogatePair), 
+ expOutputVectorLiteralSurrogatePair) + + val startsWithExpressionExpression = startsWith(new Column("col1"), new Column("col2")) + val e = intercept[UnsupportedOperationException] { + new DefaultExpressionEvaluator( + schema, startsWithExpressionExpression, BooleanType.BOOLEAN).eval(input) + } + assert(e.getMessage.contains("'STARTS_WITH' expects literal as the second input")) + + + def checkUnsupportedTypes(colType: DataType, literalType: DataType): Unit = { + val schema = new StructType() + .add("col", colType) + val expr = startsWith(new Column("col"), Literal.ofNull(literalType)) + val input = new DefaultColumnarBatch(5, schema, + Array(testColumnVector(5, colType))) + + val e = intercept[UnsupportedOperationException] { + new DefaultExpressionEvaluator( + schema, expr, BooleanType.BOOLEAN).eval(input) + } + assert(e.getMessage.contains("'STARTS_WITH' expects STRING type inputs")) + } + + checkUnsupportedTypes(BooleanType.BOOLEAN, BooleanType.BOOLEAN) + checkUnsupportedTypes(LongType.LONG, LongType.LONG) + checkUnsupportedTypes(IntegerType.INTEGER, IntegerType.INTEGER) + checkUnsupportedTypes(StringType.STRING, BooleanType.BOOLEAN) + checkUnsupportedTypes(StringType.STRING, IntegerType.INTEGER) + checkUnsupportedTypes(StringType.STRING, LongType.LONG) + } + test("evaluate expression: comparators (=, <, <=, >, >=)") { val ASCII_MAX_CHARACTER = '\u007F' val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT)) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index 2c5fb51108f..b2448ce9250 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -49,6 +49,10 @@ trait 
ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils { new Predicate("like", children.asJava) } + protected def startsWith(left: Expression, right: Expression): Predicate = { + new Predicate("starts_with", left, right) + } + protected def comparator(symbol: String, left: Expression, right: Expression): Predicate = { new Predicate(symbol, left, right) }