Skip to content

Commit

Permalink
[Kernel] Implement partition pruning
Browse files Browse the repository at this point in the history
## Description
Part of #2071 (Partition Pruning in Kernel). This PR integrates the different pieces added in previous PRs to have an end-to-end partition pruning.

## How was this patch tested?
Added `PartitionPruningSuite`
  • Loading branch information
vkorukanti authored Oct 3, 2023
1 parent f1eae09 commit 5f9b98e
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public Optional<ColumnVector> getSelectionVector() {

/**
* Iterator of rows that survived the filter.
*
* @return Closeable iterator of rows that survived the filter. It is the responsibility of
* the caller to close the iterator.
*/
Expand All @@ -85,7 +86,9 @@ public CloseableIterator<Row> getRows() {
@Override
public boolean hasNext() {
for (; rowId < maxRowId && nextRowId == -1; rowId++) {
if (selectionVector.get().getBoolean(rowId)) {
boolean isSelected = !selectionVector.get().isNullAt(rowId) &&
selectionVector.get().getBoolean(rowId);
if (isSelected) {
nextRowId = rowId;
rowId++;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
*/
package io.delta.kernel.internal;

import java.util.Optional;
import java.util.*;
import static java.util.stream.Collectors.toMap;

import io.delta.kernel.Scan;
import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.data.*;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.expressions.PredicateEvaluator;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructType;
import io.delta.kernel.utils.CloseableIterator;
import io.delta.kernel.utils.Tuple2;
Expand All @@ -34,6 +35,9 @@
import io.delta.kernel.internal.lang.Lazy;
import io.delta.kernel.internal.types.TableSchemaSerDe;
import io.delta.kernel.internal.util.InternalSchemaUtils;
import io.delta.kernel.internal.util.PartitionUtils;
import static io.delta.kernel.internal.util.InternalUtils.checkArgument;
import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnScanFileSchema;

/**
* Implementation of {@link Scan}
Expand All @@ -54,7 +58,9 @@ public class ScanImpl
private final StructType readSchema;
private final CloseableIterator<FilteredColumnarBatch> filesIter;
private final Lazy<Tuple2<Protocol, Metadata>> protocolAndMetadata;
private final Optional<Predicate> filter;
private final Lazy<Optional<Tuple2<Predicate, Predicate>>> partitionAndDataFilters;
// Partition column names in lower case.
private final Lazy<Set<String>> partitionColumnNames;

private boolean accessedScanFiles;

Expand All @@ -70,8 +76,10 @@ public ScanImpl(
this.protocolAndMetadata = protocolAndMetadata;
this.filesIter = filesIter;
this.dataPath = dataPath;

this.filter = filter;
// Computing the remaining filter requires access to the metadata. We delay loading the
// metadata as long as possible, which means the remaining filter computation is also lazy.
this.partitionAndDataFilters = new Lazy<>(() -> splitFilters(filter));
this.partitionColumnNames = new Lazy<>(() -> loadPartitionColNames());
}

/**
Expand All @@ -85,7 +93,8 @@ public CloseableIterator<FilteredColumnarBatch> getScanFiles(TableClient tableCl
throw new IllegalStateException("Scan files are already fetched from this instance");
}
accessedScanFiles = true;
return filesIter;

return applyPartitionPruning(tableClient, filesIter);
}

@Override
Expand All @@ -107,6 +116,82 @@ public Row getScanState(TableClient tableClient) {

@Override
public Optional<Predicate> getRemainingFilter() {
return filter;
return getDataFilters();
}

/**
 * Splits the given query filter into a partition-column predicate and a data-column
 * predicate using the table's partition column names.
 *
 * @param filter Optional query-level predicate.
 * @return Optional pair of (partition predicate, data predicate); empty when no filter given.
 */
private Optional<Tuple2<Predicate, Predicate>> splitFilters(Optional<Predicate> filter) {
    return filter.map(queryPredicate ->
        PartitionUtils.splitMetadataAndDataPredicates(
            queryPredicate, partitionColumnNames.get()));
}

/**
 * Returns the part of the query filter that applies to data (non-partition) columns,
 * treating {@code ALWAYS_TRUE} as no filter at all.
 */
private Optional<Predicate> getDataFilters() {
    Optional<Predicate> dataPredicate =
        partitionAndDataFilters.get().map(splitFilter -> splitFilter._2);
    return removeAlwaysTrue(dataPredicate);
}

/**
 * Returns the part of the query filter that applies only to partition columns,
 * treating {@code ALWAYS_TRUE} as no filter at all.
 */
private Optional<Predicate> getPartitionsFilters() {
    Optional<Predicate> partitionPredicate =
        partitionAndDataFilters.get().map(splitFilter -> splitFilter._1);
    return removeAlwaysTrue(partitionPredicate);
}

/**
 * Treats an {@code ALWAYS_TRUE} predicate as the absence of a predicate.
 *
 * @param predicate Optional predicate to inspect.
 * @return The input predicate, or empty when the predicate's name is {@code ALWAYS_TRUE}
 *         (compared case-insensitively).
 */
private Optional<Predicate> removeAlwaysTrue(Optional<Predicate> predicate) {
    return predicate.filter(candidate ->
        !candidate.getName().equalsIgnoreCase("ALWAYS_TRUE"));
}

/**
 * Applies the partition predicate (when one exists) to the given scan file batches and
 * returns batches whose selection vector deselects files in non-matching partitions.
 *
 * @param tableClient  {@link TableClient} used to obtain a predicate evaluator.
 * @param scanFileIter Iterator of scan file batches to prune.
 * @return Pruned iterator; the input iterator unchanged when there is no partition filter.
 */
private CloseableIterator<FilteredColumnarBatch> applyPartitionPruning(
    TableClient tableClient,
    CloseableIterator<FilteredColumnarBatch> scanFileIter) {
    Optional<Predicate> partitionPredicate = getPartitionsFilters();
    if (!partitionPredicate.isPresent()) {
        // There is no partition filter, return the scan file iterator as is.
        return scanFileIter;
    }

    Metadata metadata = protocolAndMetadata.get()._2;
    Set<String> partitionColNames = partitionColumnNames.get();
    // Map of partition column name (lower-cased) to its data type. `partitionColNames`
    // holds lower-cased names, so the schema field name must be lower-cased before the
    // membership check (previously the raw name was used, which would silently drop
    // mixed-case partition columns from the map).
    Map<String, DataType> partitionColNameToTypeMap = metadata.getSchema().fields().stream()
        .filter(field ->
            partitionColNames.contains(field.getName().toLowerCase(Locale.ENGLISH)))
        .collect(toMap(
            field -> field.getName().toLowerCase(Locale.ENGLISH),
            field -> field.getDataType()));

    // Rewrite the partition predicate to refer to the serialized partition values as they
    // appear in the scan file schema.
    Predicate predicateOnScanFileBatch = rewritePartitionPredicateOnScanFileSchema(
        partitionPredicate.get(),
        partitionColNameToTypeMap);

    PredicateEvaluator predicateEvaluator =
        tableClient.getExpressionHandler().getPredicateEvaluator(
            InternalScanFileUtils.SCAN_FILE_SCHEMA,
            predicateOnScanFileBatch);

    // Map over the passed-in iterator (previously this referenced the member `filesIter`
    // directly, bypassing the `scanFileIter` parameter).
    return scanFileIter.map(filteredScanFileBatch -> {
        // Evaluate the predicate against the batch, threading through the existing
        // selection vector so already-deselected rows stay deselected.
        ColumnVector newSelectionVector = predicateEvaluator.eval(
            filteredScanFileBatch.getData(),
            filteredScanFileBatch.getSelectionVector());
        return new FilteredColumnarBatch(
            filteredScanFileBatch.getData(),
            Optional.of(newSelectionVector));
    });
}

/**
 * Loads the partition column names from the table metadata.
 *
 * @return Unmodifiable set of partition column names, lower-cased.
 */
private Set<String> loadPartitionColNames() {
    ArrayValue partitionColumns = protocolAndMetadata.get()._2.getPartitionColumns();
    ColumnVector nameVector = partitionColumns.getElements();
    int columnCount = partitionColumns.getSize();
    Set<String> lowerCasedNames = new HashSet<>();
    for (int ordinal = 0; ordinal < columnCount; ordinal++) {
        checkArgument(!nameVector.isNullAt(ordinal),
            "Expected a non-null partition column name");
        String columnName = nameVector.getString(ordinal);
        checkArgument(columnName != null && !columnName.isEmpty(),
            "Expected non-null and non-empty partition column name");
        // Store lower-cased for case-insensitive lookups downstream.
        lowerCasedNames.add(columnName.toLowerCase(Locale.ENGLISH));
    }
    return Collections.unmodifiableSet(lowerCasedNames);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public static Tuple2<Predicate, Predicate> splitMetadataAndDataPredicates(
* String type partition values don't need any deserialization.
*
* @param predicate Predicate containing filters only on partition columns.
* @param partitionColNameTypes Map of partition columns and their types.
* @param partitionColNameTypes Map of partition column name (in lower case) to its type.
* @return Rewritten predicate that references the partition values as they appear in the
* scan file schema.
*/
public static Predicate rewritePartitionPredicateOnScanFileSchema(
Expand All @@ -170,7 +170,10 @@ private static Expression rewritePartitionColumnRef(
if (expression instanceof Column) {
Column column = (Column) expression;
String partColName = column.getNames()[0];
DataType partColType = partitionColNameTypes.get(partColName);
DataType partColType = partitionColNameTypes.get(partColName.toLowerCase(Locale.ROOT));
if (partColType == null) {
throw new IllegalArgumentException(partColName + " has no data type in metadata");
}

Expression elementAt =
new ScalarExpression(
Expand Down
Loading

0 comments on commit 5f9b98e

Please sign in to comment.