diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
index 2e2966e8614..d65721717ac 100644
--- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
+++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java
@@ -101,6 +101,11 @@
*
SQL semantic: expr1 IS NOT DISTINCT FROM expr2
* Since version: 3.3.0
*
+ * Name: STARTS_WITH
+ *
+ * - SQL semantic:
expr STARTS_WITH expr
+ * - Since version: 3.4.0
+ *
*
*
* @since 3.0.0
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
index 2fd009b79cb..a5a8f812d22 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java
@@ -299,6 +299,18 @@ ExpressionTransformResult visitLike(final Predicate like) {
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
}
+ @Override
+ ExpressionTransformResult visitStartsWith(Predicate startsWith) {
+ List children =
+ startsWith.getChildren().stream().map(this::visit).collect(toList());
+ Predicate transformedExpression =
+ StartsWithExpressionEvaluator.validateAndTransform(
+ startsWith,
+ children.stream().map(e -> e.expression).collect(toList()),
+ children.stream().map(e -> e.outputType).collect(toList()));
+ return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
+ }
+
private Predicate validateIsPredicate(
Expression baseExpression, ExpressionTransformResult result) {
checkArgument(
@@ -610,6 +622,12 @@ ColumnVector visitLike(final Predicate like) {
children, children.stream().map(this::visit).collect(toList()));
}
+ @Override
+ ColumnVector visitStartsWith(Predicate startsWith) {
+ return StartsWithExpressionEvaluator.eval(
+ startsWith.getChildren().stream().map(this::visit).collect(toList()));
+ }
+
/**
* Utility method to evaluate inputs to the binary input expression. Also validates the
* evaluated expression result {@link ColumnVector}s are of the same size.
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
index b59db8689ab..016d86bb70e 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java
@@ -15,12 +15,14 @@
*/
package io.delta.kernel.defaults.internal.expressions;
+import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
import io.delta.kernel.data.ArrayValue;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.expressions.Expression;
+import io.delta.kernel.expressions.Literal;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.types.*;
import java.math.BigDecimal;
@@ -383,4 +385,28 @@ private ColumnVector getVector(int rowId) {
}
};
}
+
+ /**
+ * Checks the argument count of an expression. throws {@code unsupportedExpressionException} if
+ * argument count mismatched.
+ */
+ static void checkArgsCount(Expression expr, int expectedCount, String exprName, String context) {
+ if (expr.getChildren().size() != expectedCount) {
+ throw unsupportedExpressionException(
+ expr, String.format("Invalid number of inputs of %s expression, %s", exprName, context));
+ }
+ }
+
+ static void checkIsStringType(DataType dataType, Expression parentExpr, String errorMessage) {
+ if (StringType.STRING.equals(dataType)) {
+ return;
+ }
+ throw unsupportedExpressionException(parentExpr, errorMessage);
+ }
+
+ static void checkIsLiteral(Expression expr, Expression parentExpr, String errorMessage) {
+ if (!(expr instanceof Literal)) {
+ throw unsupportedExpressionException(parentExpr, errorMessage);
+ }
+ }
}
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
index cf67d87dd23..d7d30413f9a 100644
--- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java
@@ -63,6 +63,8 @@ abstract class ExpressionVisitor {
abstract R visitLike(Predicate predicate);
+ abstract R visitStartsWith(Predicate predicate);
+
final R visit(Expression expression) {
if (expression instanceof PartitionValueExpression) {
return visitPartitionValue((PartitionValueExpression) expression);
@@ -113,6 +115,8 @@ private R visitScalarExpression(ScalarExpression expression) {
return visitTimeAdd(expression);
case "LIKE":
return visitLike(new Predicate(name, children));
+ case "STARTS_WITH":
+ return visitStartsWith(new Predicate(name, children));
default:
throw new UnsupportedOperationException(
String.format("Scalar expression `%s` is not supported.", name));
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java
new file mode 100644
index 00000000000..aeb0ee4dd6d
--- /dev/null
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright (2023) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.delta.kernel.defaults.internal.expressions;
+
+import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;
+
+import io.delta.kernel.data.ColumnVector;
+import io.delta.kernel.expressions.Expression;
+import io.delta.kernel.expressions.Predicate;
+import io.delta.kernel.internal.util.Utils;
+import io.delta.kernel.types.BooleanType;
+import io.delta.kernel.types.DataType;
+import java.util.List;
+
+public class StartsWithExpressionEvaluator {
+
+ /** Validates and transforms the {@code starts_with} expression. */
+ static Predicate validateAndTransform(
+ Predicate startsWith,
+ List childrenExpressions,
+ List childrenOutputTypes) {
+ checkArgsCount(
+ startsWith,
+ /* expectedCount= */ 2,
+ startsWith.getName(),
+ "Example usage: STARTS_WITH(column, 'test')");
+ for (DataType dataType : childrenOutputTypes) {
+ checkIsStringType(dataType, startsWith, "'STARTS_WITH' expects STRING type inputs");
+ }
+
+ // TODO: support non literal as the second input of starts with.
+ checkIsLiteral(
+ childrenExpressions.get(1),
+ startsWith,
+ "'STARTS_WITH' expects literal as the second input");
+ return new Predicate(startsWith.getName(), childrenExpressions);
+ }
+
+ static ColumnVector eval(List childrenVectors) {
+ return new ColumnVector() {
+ final ColumnVector left = childrenVectors.get(0);
+ final ColumnVector right = childrenVectors.get(1);
+
+ @Override
+ public DataType getDataType() {
+ return BooleanType.BOOLEAN;
+ }
+
+ @Override
+ public int getSize() {
+ return left.getSize();
+ }
+
+ @Override
+ public void close() {
+ Utils.closeCloseables(left, right);
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ if (isNullAt(rowId)) {
+ // The return value is undefined and can be anything, if the slot for rowId is null.
+ return false;
+ }
+ return left.getString(rowId).startsWith(right.getString(rowId));
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ return left.isNullAt(rowId) || right.isNullAt(rowId);
+ }
+ };
+ }
+}
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
index bb1cc6b6809..eedea5faf62 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala
@@ -582,6 +582,91 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
"LIKE expression has invalid escape sequence"))
}
+ test("evaluate expression: starts with") {
+ val col1 = stringVector(Seq[String]("one", "two", "t%hree", "four", null, null, "%"))
+ val col2 = stringVector(Seq[String]("o", "t", "T", "4", "f", null, null))
+ val schema = new StructType()
+ .add("col1", StringType.STRING)
+ .add("col2", StringType.STRING)
+ val input = new DefaultColumnarBatch(col1.getSize, schema, Array(col1, col2))
+
+ val startsWithExpressionLiteral = startsWith(new Column("col1"), Literal.ofString("t%"))
+ val expOutputVectorLiteral =
+ booleanVector(Seq[BooleanJ](false, false, true, false, null, null, false))
+ checkBooleanVectors(new DefaultExpressionEvaluator(
+ schema, startsWithExpressionLiteral, BooleanType.BOOLEAN).eval(input), expOutputVectorLiteral)
+
+ val startsWithExpressionNullLiteral = startsWith(new Column("col1"), Literal.ofString(null))
+ val allNullVector =
+ booleanVector(Seq[BooleanJ](null, null, null, null, null, null, null))
+ checkBooleanVectors(new DefaultExpressionEvaluator(
+ schema, startsWithExpressionNullLiteral, BooleanType.BOOLEAN).eval(input), allNullVector)
+
+ // Two literal expressions on both sides
+ val startsWithExpressionAlwaysTrue = startsWith(Literal.ofString("ABC"), Literal.ofString("A"))
+ val allTrueVector = booleanVector(Seq[BooleanJ](true, true, true, true, true, true, true))
+ checkBooleanVectors(new DefaultExpressionEvaluator(
+ schema, startsWithExpressionAlwaysTrue, BooleanType.BOOLEAN).eval(input), allTrueVector)
+
+ val startsWithExpressionAlwaysFalse =
+ startsWith(Literal.ofString("ABC"), Literal.ofString("_B%"))
+ val allFalseVector =
+ booleanVector(Seq[BooleanJ](false, false, false, false, false, false, false))
+ checkBooleanVectors(new DefaultExpressionEvaluator(
+ schema, startsWithExpressionAlwaysFalse, BooleanType.BOOLEAN).eval(input), allFalseVector)
+
+ // scalastyle:off nonascii
+ val colUnicode = stringVector(Seq[String]("中文", "中", "文"))
+ val schemaUnicode = new StructType().add("col", StringType.STRING)
+ val inputUnicode = new DefaultColumnarBatch(colUnicode.getSize,
+ schemaUnicode, Array(colUnicode))
+ val startsWithExpressionUnicode = startsWith(new Column("col"), Literal.ofString("中"))
+ val expOutputVectorLiteralUnicode = booleanVector(Seq[BooleanJ](true, true, false))
+ checkBooleanVectors(new DefaultExpressionEvaluator(schemaUnicode,
+ startsWithExpressionUnicode,
+ BooleanType.BOOLEAN).eval(inputUnicode), expOutputVectorLiteralUnicode)
+
+ // scalastyle:off nonascii
+ val colSurrogatePair = stringVector(Seq[String]("💕😉💕", "😉💕", "💕"))
+ val schemaSurrogatePair = new StructType().add("col", StringType.STRING)
+ val inputSurrogatePair = new DefaultColumnarBatch(colSurrogatePair.getSize,
+ schemaUnicode, Array(colSurrogatePair))
+ val startsWithExpressionSurrogatePair = startsWith(new Column("col"), Literal.ofString("💕"))
+ val expOutputVectorLiteralSurrogatePair = booleanVector(Seq[BooleanJ](true, false, true))
+ checkBooleanVectors(new DefaultExpressionEvaluator(schemaSurrogatePair,
+ startsWithExpressionSurrogatePair,
+ BooleanType.BOOLEAN).eval(inputSurrogatePair), expOutputVectorLiteralSurrogatePair)
+
+ val startsWithExpressionExpression = startsWith(new Column("col1"), new Column("col2"))
+ val e = intercept[UnsupportedOperationException] {
+ new DefaultExpressionEvaluator(
+ schema, startsWithExpressionExpression, BooleanType.BOOLEAN).eval(input)
+ }
+ assert(e.getMessage.contains("'STARTS_WITH' expects literal as the second input"))
+
+
+ def checkUnsupportedTypes(colType: DataType, literalType: DataType): Unit = {
+ val schema = new StructType()
+ .add("col", colType)
+ val expr = startsWith(new Column("col"), Literal.ofNull(literalType))
+ val input = new DefaultColumnarBatch(5, schema,
+ Array(testColumnVector(5, colType)))
+
+ val e = intercept[UnsupportedOperationException] {
+ new DefaultExpressionEvaluator(
+ schema, expr, BooleanType.BOOLEAN).eval(input)
+ }
+ assert(e.getMessage.contains("'STARTS_WITH' expects STRING type inputs"))
+ }
+
+ checkUnsupportedTypes(BooleanType.BOOLEAN, BooleanType.BOOLEAN)
+ checkUnsupportedTypes(LongType.LONG, LongType.LONG)
+ checkUnsupportedTypes(IntegerType.INTEGER, IntegerType.INTEGER)
+ checkUnsupportedTypes(StringType.STRING, BooleanType.BOOLEAN)
+ checkUnsupportedTypes(StringType.STRING, IntegerType.INTEGER)
+ checkUnsupportedTypes(StringType.STRING, LongType.LONG)
+ }
+
test("evaluate expression: comparators (=, <, <=, >, >=)") {
val ASCII_MAX_CHARACTER = '\u007F'
val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala
index 2c5fb51108f..b2448ce9250 100644
--- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala
@@ -49,6 +49,10 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils {
new Predicate("like", children.asJava)
}
+ protected def startsWith(left: Expression, right: Expression): Predicate = {
+ new Predicate("starts_with", left, right)
+ }
+
protected def comparator(symbol: String, left: Expression, right: Expression): Predicate = {
new Predicate(symbol, left, right)
}