Skip to content

Commit

Permalink
Fix error with multiple nested partition columns on Iceberg
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyangli34 committed Jan 14, 2025
1 parent 223828d commit 2986072
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.plugin.iceberg;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import io.trino.spi.connector.ConnectorPartitioningHandle;
Expand All @@ -27,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
Expand Down Expand Up @@ -64,26 +66,39 @@ private static Map<Integer, List<Integer>> buildDataPaths(PartitionSpec spec)
{
Set<Integer> partitionFieldIds = spec.fields().stream().map(PartitionField::sourceId).collect(toImmutableSet());

int channel = 0;
Map<Integer, List<Integer>> fieldInfo = new HashMap<>();
for (Types.NestedField field : spec.schema().asStruct().fields()) {
// Partition fields can only be nested in a struct
if (field.type() instanceof Types.StructType nestedStruct) {
if (buildDataPaths(partitionFieldIds, nestedStruct, new ArrayDeque<>(List.of(channel)), fieldInfo)) {
channel++;
}
buildDataPaths(partitionFieldIds, nestedStruct, new ArrayDeque<>(ImmutableList.of(field.fieldId())), fieldInfo);
}
else if (field.type().isPrimitiveType() && partitionFieldIds.contains(field.fieldId())) {
fieldInfo.put(field.fieldId(), ImmutableList.of(channel));
channel++;
fieldInfo.put(field.fieldId(), ImmutableList.of(field.fieldId()));
}
}
return fieldInfo;

// assign channel for top level fields based on the order of the field id
List<Integer> sortedFieldIds = fieldInfo.keySet().stream().sorted().collect(toImmutableList());
ImmutableMap.Builder<Integer, List<Integer>> builder = ImmutableMap.builder();
Map<Integer, Integer> fieldChannels = new HashMap<>();
AtomicInteger channel = new AtomicInteger();
for (int fieldId : sortedFieldIds) {
List<Integer> dataPath = fieldInfo.get(fieldId);
int fieldChannel = fieldChannels.computeIfAbsent(dataPath.getFirst(), _ -> channel.getAndIncrement());

List<Integer> channelDataPath = ImmutableList.<Integer>builder()
.add(fieldChannel)
.addAll(dataPath.stream().skip(1).iterator())
.build();

builder.put(fieldId, channelDataPath);
}

return builder.buildOrThrow();
}

private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, ArrayDeque<Integer> currentPaths, Map<Integer, List<Integer>> dataPaths)
private static void buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, ArrayDeque<Integer> currentPaths, Map<Integer, List<Integer>> dataPaths)
{
boolean hasPartitionFields = false;
List<Types.NestedField> fields = struct.fields();
for (int fieldOrdinal = 0; fieldOrdinal < fields.size(); fieldOrdinal++) {
Types.NestedField field = fields.get(fieldOrdinal);
Expand All @@ -92,16 +107,14 @@ private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.Stru
currentPaths.addLast(fieldOrdinal);
org.apache.iceberg.types.Type type = field.type();
if (type instanceof Types.StructType nestedStruct) {
hasPartitionFields = buildDataPaths(partitionFieldIds, nestedStruct, currentPaths, dataPaths) || hasPartitionFields;
buildDataPaths(partitionFieldIds, nestedStruct, currentPaths, dataPaths);
}
// Map and List types are not supported in partitioning
if (type.isPrimitiveType() && partitionFieldIds.contains(fieldId)) {
dataPaths.put(fieldId, ImmutableList.copyOf(currentPaths));
hasPartitionFields = true;
}
currentPaths.removeLast();
}
return hasPartitionFields;
}

public long getCacheKeyHint()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@
import static io.trino.plugin.iceberg.IcebergTestUtils.getFileSystemFactory;
import static io.trino.plugin.iceberg.IcebergTestUtils.getHiveMetastore;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RowType.field;
import static io.trino.spi.type.RowType.rowType;
import static io.trino.testing.MaterializedResult.DEFAULT_PRECISION;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -683,4 +687,25 @@ private BaseTable loadTable(String tableName)
{
return IcebergTestUtils.loadTable(tableName, metastore, fileSystemFactory, "hive", "tpch");
}

@Test
public void testPartitionColumns()
{
String tableName = "test_partition_columns_" + randomNameSuffix();
assertUpdate(String.format("""
CREATE TABLE %s WITH (partitioning = ARRAY[
'"r1.f1"',
'bucket(b1, 4)'
]) AS
SELECT
CAST(ROW(1, 2) AS ROW(f1 integer, f2 integer)) as r1
, CAST('b' AS VARCHAR) as b1""", tableName), 1);

assertThat(query("SELECT partition FROM \"" + tableName + "$partitions\""))
.result()
.hasTypes(ImmutableList.of(rowType(
field("r1.f1", INTEGER),
field("b1_bucket", INTEGER) )))
.matches("SELECT CAST(ROW(1, 3) AS ROW(\"r1.f1\" integer, b1_bucket integer))");
}
}

0 comments on commit 2986072

Please sign in to comment.