
Commit deeb04b

Spark 3.3,3.4: Align RewritePositionDeleteFilesSparkAction filter with Spark case sensitivity (#11710)

huaxingao authored Dec 7, 2024
1 parent 2210e28 commit deeb04b
Showing 4 changed files with 42 additions and 4 deletions.
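The fix threads Spark's session-level case sensitivity into the position-delete rewrite action: the constructor captures it via SparkUtil.caseSensitive(spark), and planFiles passes it to the scan. For context, here is a minimal sketch of what such a helper amounts to; the real implementation lives in org.apache.iceberg.spark.SparkUtil, so the body below is an assumption:

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.internal.SQLConf;

class CaseSensitivitySketch {
  // Hypothetical stand-in for SparkUtil.caseSensitive: read spark.sql.caseSensitive,
  // which Spark defaults to false.
  static boolean caseSensitive(SparkSession spark) {
    return Boolean.parseBoolean(
        spark.conf().get(SQLConf.CASE_SENSITIVE().key(), "false"));
  }
}

The action diff follows; the same change appears twice below, once per Spark version tree: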
@@ -59,6 +59,7 @@
 import org.apache.iceberg.relocated.com.google.common.math.IntMath;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Types.StructType;
 import org.apache.iceberg.util.PartitionUtil;
 import org.apache.iceberg.util.PropertyUtil;
@@ -92,11 +93,13 @@ public class RewritePositionDeleteFilesSparkAction
   private int maxCommits;
   private boolean partialProgressEnabled;
   private RewriteJobOrder rewriteJobOrder;
+  private boolean caseSensitive;

   RewritePositionDeleteFilesSparkAction(SparkSession spark, Table table) {
     super(spark);
     this.table = table;
     this.rewriter = new SparkBinPackPositionDeletesRewriter(spark(), table);
+    this.caseSensitive = SparkUtil.caseSensitive(spark);
   }

   @Override
@@ -158,7 +161,7 @@ private StructLikeMap<List<List<PositionDeletesScanTask>>> planFileGroups() {
   private CloseableIterable<PositionDeletesScanTask> planFiles(Table deletesTable) {
     PositionDeletesBatchScan scan = (PositionDeletesBatchScan) deletesTable.newBatchScan();
     return CloseableIterable.transform(
-        scan.baseTableFilter(filter).ignoreResiduals().planFiles(),
+        scan.baseTableFilter(filter).caseSensitive(caseSensitive).ignoreResiduals().planFiles(),
         task -> (PositionDeletesScanTask) task);
   }
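With the scan honoring the session setting, filter field resolution now matches Spark semantics end to end. An illustrative usage sketch, not taken from this commit, assuming a table with a lower-case column c1:

// Default Spark behavior (spark.sql.caseSensitive=false): "C1" resolves to column "c1".
spark.conf().set("spark.sql.caseSensitive", "false");
SparkActions.get(spark)
    .rewritePositionDeletes(table)
    .filter(Expressions.equal("C1", 1))
    .execute();

// With case sensitivity enabled, the same filter fails during planning with
// ValidationException: Cannot find field 'C1' in struct.
spark.conf().set("spark.sql.caseSensitive", "true");

The matching test change covers both modes: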
@@ -21,6 +21,7 @@
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.apache.spark.sql.functions.expr;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;

 import java.io.IOException;
 import java.util.Arrays;
@@ -71,6 +72,7 @@
 import org.apache.iceberg.util.StructLikeMap;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.sql.types.StructType;
 import org.junit.After;
 import org.junit.Assert;
@@ -228,7 +230,8 @@ public void testRewriteFilter() throws Exception {
     Expression filter =
         Expressions.and(
             Expressions.greaterThan("c3", "0"), // should have no effect
-            Expressions.or(Expressions.equal("c1", 1), Expressions.equal("c1", 2)));
+            // "C1" should work because Spark defaults case sensitivity to false.
+            Expressions.or(Expressions.equal("C1", 1), Expressions.equal("C1", 2)));

     Result result =
         SparkActions.get(spark)
@@ -250,6 +253,19 @@
     List<Object[]> actualDeletes = deleteRecords(table);
     assertEquals("Rows must match", expectedRecords, actualRecords);
     assertEquals("Position deletes must match", expectedDeletes, actualDeletes);
+
+    withSQLConf(
+        ImmutableMap.of(SQLConf.CASE_SENSITIVE().key(), "true"),
+        () -> {
+          assertThatThrownBy(
+                  () ->
+                      SparkActions.get(spark)
+                          .rewritePositionDeletes(table)
+                          .filter(filter)
+                          .execute())
+              .isInstanceOf(ValidationException.class)
+              .hasMessageContaining("Cannot find field 'C1' in struct");
+        });
   }

   @Test
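The negative test wraps the action in withSQLConf, a helper from Iceberg's Spark test base that applies the given SQL confs for the duration of a callback and then restores the previous values. A minimal sketch of that shape, with hypothetical names and a simplified Runnable signature:

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.SparkSession;

class SqlConfTestSketch {
  // Hypothetical stand-in for the test base's withSQLConf: set confs, run, restore.
  static void withSQLConf(SparkSession spark, Map<String, String> conf, Runnable action) {
    Map<String, String> previous = new HashMap<>();
    conf.keySet().forEach(key -> previous.put(key, spark.conf().get(key, null)));
    conf.forEach((key, value) -> spark.conf().set(key, value));
    try {
      action.run();
    } finally {
      previous.forEach(
          (key, value) -> {
            if (value == null) {
              spark.conf().unset(key);
            } else {
              spark.conf().set(key, value);
            }
          });
    }
  }
}

The second Spark version tree carries the same action and test changes: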
@@ -59,6 +59,7 @@
 import org.apache.iceberg.relocated.com.google.common.math.IntMath;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.iceberg.types.Types.StructType;
 import org.apache.iceberg.util.PartitionUtil;
 import org.apache.iceberg.util.PropertyUtil;
@@ -92,11 +93,13 @@ public class RewritePositionDeleteFilesSparkAction
   private int maxCommits;
   private boolean partialProgressEnabled;
   private RewriteJobOrder rewriteJobOrder;
+  private boolean caseSensitive;

   RewritePositionDeleteFilesSparkAction(SparkSession spark, Table table) {
     super(spark);
     this.table = table;
     this.rewriter = new SparkBinPackPositionDeletesRewriter(spark(), table);
+    this.caseSensitive = SparkUtil.caseSensitive(spark);
   }

   @Override
@@ -159,7 +162,7 @@ private CloseableIterable<PositionDeletesScanTask> planFiles(Table deletesTable)
     PositionDeletesBatchScan scan = (PositionDeletesBatchScan) deletesTable.newBatchScan();

     return CloseableIterable.transform(
-        scan.baseTableFilter(filter).ignoreResiduals().planFiles(),
+        scan.baseTableFilter(filter).caseSensitive(caseSensitive).ignoreResiduals().planFiles(),
         task -> (PositionDeletesScanTask) task);
   }

@@ -21,6 +21,7 @@
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.apache.spark.sql.functions.expr;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;

 import java.io.IOException;
 import java.util.Arrays;
@@ -72,6 +73,7 @@
 import org.apache.iceberg.util.StructLikeMap;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.sql.types.StructType;
 import org.junit.After;
 import org.junit.Assert;
@@ -265,7 +267,8 @@ public void testRewriteFilter() throws Exception {
     Expression filter =
         Expressions.and(
             Expressions.greaterThan("c3", "0"), // should have no effect
-            Expressions.or(Expressions.equal("c1", 1), Expressions.equal("c1", 2)));
+            // "C1" should work because Spark defaults case sensitivity to false.
+            Expressions.or(Expressions.equal("C1", 1), Expressions.equal("C1", 2)));

     Result result =
         SparkActions.get(spark)
@@ -287,6 +290,19 @@
     List<Object[]> actualDeletes = deleteRecords(table);
     assertEquals("Rows must match", expectedRecords, actualRecords);
     assertEquals("Position deletes must match", expectedDeletes, actualDeletes);
+
+    withSQLConf(
+        ImmutableMap.of(SQLConf.CASE_SENSITIVE().key(), "true"),
+        () -> {
+          assertThatThrownBy(
+                  () ->
+                      SparkActions.get(spark)
+                          .rewritePositionDeletes(table)
+                          .filter(filter)
+                          .execute())
+              .isInstanceOf(ValidationException.class)
+              .hasMessageContaining("Cannot find field 'C1' in struct");
+        });
   }

   @Test
