Skip to content

Commit

Permalink
[Kernel] Add STARTS_WITH expression and default impl (#4007)
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [ ] Spark
- [ ] Standalone
- [ ] Flink
- [x] Kernel
- [ ] Other (fill in here)

## Description

<!--
- Describe what this PR changes.
- Describe why we need the change.
 
If this PR resolves an issue be sure to include "Resolves #XXX" to
correctly link and close the issue upon merge.
-->
Initial implementation of STARTS_WITH expression at this moment, we only
support b as literal expression.

This is 1/n for addressing
#2539, the logic of data
skipping will be done in the following PRs

## How was this patch tested?

<!--
If tests were added, say they were added here. Please make sure to test
the changes thoroughly including negative and positive cases if
possible.
If the changes were tested in any way other than unit tests, please
clarify how you tested step by step (ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future).
If the changes were not tested, please explain why.
-->
Added test cases in DefaultExpressionEvaluatorSuite.scala

## Does this PR introduce _any_ user-facing changes?

<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
No

---------

Co-authored-by: Xin Huang <[email protected]>
  • Loading branch information
huan233usc and huan233usc authored Jan 6, 2025
1 parent acfb8df commit b6745bb
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@
* <li>SQL semantic: <code>expr1 IS NOT DISTINCT FROM expr2</code>
* <li>Since version: 3.3.0
* </ul>
* <li>Name: <code>STARTS_WITH</code>
* <ul>
* <li>SQL semantic: <code>expr STARTS_WITH expr</code>
* <li>Since version: 3.4.0
* </ul>
* </ol>
*
* @since 3.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,18 @@ ExpressionTransformResult visitLike(final Predicate like) {
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
}

@Override
ExpressionTransformResult visitStartsWith(Predicate startsWith) {
List<ExpressionTransformResult> children =
startsWith.getChildren().stream().map(this::visit).collect(toList());
Predicate transformedExpression =
StartsWithExpressionEvaluator.validateAndTransform(
startsWith,
children.stream().map(e -> e.expression).collect(toList()),
children.stream().map(e -> e.outputType).collect(toList()));
return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN);
}

private Predicate validateIsPredicate(
Expression baseExpression, ExpressionTransformResult result) {
checkArgument(
Expand Down Expand Up @@ -610,6 +622,12 @@ ColumnVector visitLike(final Predicate like) {
children, children.stream().map(this::visit).collect(toList()));
}

@Override
ColumnVector visitStartsWith(Predicate startsWith) {
return StartsWithExpressionEvaluator.eval(
startsWith.getChildren().stream().map(this::visit).collect(toList()));
}

/**
* Utility method to evaluate inputs to the binary input expression. Also validates the
* evaluated expression result {@link ColumnVector}s are of the same size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
*/
package io.delta.kernel.defaults.internal.expressions;

import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;

import io.delta.kernel.data.ArrayValue;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.types.*;
import java.math.BigDecimal;
Expand Down Expand Up @@ -383,4 +385,28 @@ private ColumnVector getVector(int rowId) {
}
};
}

/**
* Checks the argument count of an expression. throws {@code unsupportedExpressionException} if
* argument count mismatched.
*/
static void checkArgsCount(Expression expr, int expectedCount, String exprName, String context) {
if (expr.getChildren().size() != expectedCount) {
throw unsupportedExpressionException(
expr, String.format("Invalid number of inputs of %s expression, %s", exprName, context));
}
}

static void checkIsStringType(DataType dataType, Expression parentExpr, String errorMessage) {
if (StringType.STRING.equals(dataType)) {
return;
}
throw unsupportedExpressionException(parentExpr, errorMessage);
}

static void checkIsLiteral(Expression expr, Expression parentExpr, String errorMessage) {
if (!(expr instanceof Literal)) {
throw unsupportedExpressionException(parentExpr, errorMessage);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ abstract class ExpressionVisitor<R> {

abstract R visitLike(Predicate predicate);

abstract R visitStartsWith(Predicate predicate);

final R visit(Expression expression) {
if (expression instanceof PartitionValueExpression) {
return visitPartitionValue((PartitionValueExpression) expression);
Expand Down Expand Up @@ -113,6 +115,8 @@ private R visitScalarExpression(ScalarExpression expression) {
return visitTimeAdd(expression);
case "LIKE":
return visitLike(new Predicate(name, children));
case "STARTS_WITH":
return visitStartsWith(new Predicate(name, children));
default:
throw new UnsupportedOperationException(
String.format("Scalar expression `%s` is not supported.", name));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (2023) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.defaults.internal.expressions;

import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.expressions.Expression;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.internal.util.Utils;
import io.delta.kernel.types.BooleanType;
import io.delta.kernel.types.DataType;
import java.util.List;

public class StartsWithExpressionEvaluator {

/** Validates and transforms the {@code starts_with} expression. */
static Predicate validateAndTransform(
Predicate startsWith,
List<Expression> childrenExpressions,
List<DataType> childrenOutputTypes) {
checkArgsCount(
startsWith,
/* expectedCount= */ 2,
startsWith.getName(),
"Example usage: STARTS_WITH(column, 'test')");
for (DataType dataType : childrenOutputTypes) {
checkIsStringType(dataType, startsWith, "'STARTS_WITH' expects STRING type inputs");
}

// TODO: support non literal as the second input of starts with.
checkIsLiteral(
childrenExpressions.get(1),
startsWith,
"'STARTS_WITH' expects literal as the second input");
return new Predicate(startsWith.getName(), childrenExpressions);
}

static ColumnVector eval(List<ColumnVector> childrenVectors) {
return new ColumnVector() {
final ColumnVector left = childrenVectors.get(0);
final ColumnVector right = childrenVectors.get(1);

@Override
public DataType getDataType() {
return BooleanType.BOOLEAN;
}

@Override
public int getSize() {
return left.getSize();
}

@Override
public void close() {
Utils.closeCloseables(left, right);
}

@Override
public boolean getBoolean(int rowId) {
if (isNullAt(rowId)) {
// The return value is undefined and can be anything, if the slot for rowId is null.
return false;
}
return left.getString(rowId).startsWith(right.getString(rowId));
}

@Override
public boolean isNullAt(int rowId) {
return left.isNullAt(rowId) || right.isNullAt(rowId);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,91 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa
"LIKE expression has invalid escape sequence"))
}

test("evaluate expression: starts with") {
val col1 = stringVector(Seq[String]("one", "two", "t%hree", "four", null, null, "%"))
val col2 = stringVector(Seq[String]("o", "t", "T", "4", "f", null, null))
val schema = new StructType()
.add("col1", StringType.STRING)
.add("col2", StringType.STRING)
val input = new DefaultColumnarBatch(col1.getSize, schema, Array(col1, col2))

val startsWithExpressionLiteral = startsWith(new Column("col1"), Literal.ofString("t%"))
val expOutputVectorLiteral =
booleanVector(Seq[BooleanJ](false, false, true, false, null, null, false))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionLiteral, BooleanType.BOOLEAN).eval(input), expOutputVectorLiteral)

val startsWithExpressionNullLiteral = startsWith(new Column("col1"), Literal.ofString(null))
val allNullVector =
booleanVector(Seq[BooleanJ](null, null, null, null, null, null, null))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionNullLiteral, BooleanType.BOOLEAN).eval(input), allNullVector)

// Two literal expressions on both sides
val startsWithExpressionAlwaysTrue = startsWith(Literal.ofString("ABC"), Literal.ofString("A"))
val allTrueVector = booleanVector(Seq[BooleanJ](true, true, true, true, true, true, true))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionAlwaysTrue, BooleanType.BOOLEAN).eval(input), allTrueVector)

val startsWithExpressionAlwaysFalse =
startsWith(Literal.ofString("ABC"), Literal.ofString("_B%"))
val allFalseVector =
booleanVector(Seq[BooleanJ](false, false, false, false, false, false, false))
checkBooleanVectors(new DefaultExpressionEvaluator(
schema, startsWithExpressionAlwaysFalse, BooleanType.BOOLEAN).eval(input), allFalseVector)

// scalastyle:off nonascii
val colUnicode = stringVector(Seq[String]("中文", "", ""))
val schemaUnicode = new StructType().add("col", StringType.STRING)
val inputUnicode = new DefaultColumnarBatch(colUnicode.getSize,
schemaUnicode, Array(colUnicode))
val startsWithExpressionUnicode = startsWith(new Column("col"), Literal.ofString(""))
val expOutputVectorLiteralUnicode = booleanVector(Seq[BooleanJ](true, true, false))
checkBooleanVectors(new DefaultExpressionEvaluator(schemaUnicode,
startsWithExpressionUnicode,
BooleanType.BOOLEAN).eval(inputUnicode), expOutputVectorLiteralUnicode)

// scalastyle:off nonascii
val colSurrogatePair = stringVector(Seq[String]("💕😉💕", "😉💕", "💕"))
val schemaSurrogatePair = new StructType().add("col", StringType.STRING)
val inputSurrogatePair = new DefaultColumnarBatch(colSurrogatePair.getSize,
schemaUnicode, Array(colSurrogatePair))
val startsWithExpressionSurrogatePair = startsWith(new Column("col"), Literal.ofString("💕"))
val expOutputVectorLiteralSurrogatePair = booleanVector(Seq[BooleanJ](true, false, true))
checkBooleanVectors(new DefaultExpressionEvaluator(schemaSurrogatePair,
startsWithExpressionSurrogatePair,
BooleanType.BOOLEAN).eval(inputSurrogatePair), expOutputVectorLiteralSurrogatePair)

val startsWithExpressionExpression = startsWith(new Column("col1"), new Column("col2"))
val e = intercept[UnsupportedOperationException] {
new DefaultExpressionEvaluator(
schema, startsWithExpressionExpression, BooleanType.BOOLEAN).eval(input)
}
assert(e.getMessage.contains("'STARTS_WITH' expects literal as the second input"))


def checkUnsupportedTypes(colType: DataType, literalType: DataType): Unit = {
val schema = new StructType()
.add("col", colType)
val expr = startsWith(new Column("col"), Literal.ofNull(literalType))
val input = new DefaultColumnarBatch(5, schema,
Array(testColumnVector(5, colType)))

val e = intercept[UnsupportedOperationException] {
new DefaultExpressionEvaluator(
schema, expr, BooleanType.BOOLEAN).eval(input)
}
assert(e.getMessage.contains("'STARTS_WITH' expects STRING type inputs"))
}

checkUnsupportedTypes(BooleanType.BOOLEAN, BooleanType.BOOLEAN)
checkUnsupportedTypes(LongType.LONG, LongType.LONG)
checkUnsupportedTypes(IntegerType.INTEGER, IntegerType.INTEGER)
checkUnsupportedTypes(StringType.STRING, BooleanType.BOOLEAN)
checkUnsupportedTypes(StringType.STRING, IntegerType.INTEGER)
checkUnsupportedTypes(StringType.STRING, LongType.LONG)
}

test("evaluate expression: comparators (=, <, <=, >, >=)") {
val ASCII_MAX_CHARACTER = '\u007F'
val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils {
new Predicate("like", children.asJava)
}

protected def startsWith(left: Expression, right: Expression): Predicate = {
new Predicate("starts_with", left, right)
}

protected def comparator(symbol: String, left: Expression, right: Expression): Predicate = {
new Predicate(symbol, left, right)
}
Expand Down

0 comments on commit b6745bb

Please sign in to comment.