diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedReaderBuilder.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedReaderBuilder.java
index 3915ff1f1a32..398f42eb1ce7 100644
--- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedReaderBuilder.java
+++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedReaderBuilder.java
@@ -20,6 +20,7 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.IntStream;
 import org.apache.arrow.memory.BufferAllocator;
@@ -47,6 +48,7 @@ public class VectorizedReaderBuilder extends TypeWithSchemaVisitor<VectorizedReader<?>> {
   private final Map<Integer, ?> idToConstant;
   private final boolean setArrowValidityVector;
   private final Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory;
+  private final BiFunction<Type, Object, Object> convert;
 
   public VectorizedReaderBuilder(
       Schema expectedSchema,
@@ -54,6 +56,22 @@ public VectorizedReaderBuilder(
       boolean setArrowValidityVector,
       Map<Integer, ?> idToConstant,
       Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory) {
+    this(
+        expectedSchema,
+        parquetSchema,
+        setArrowValidityVector,
+        idToConstant,
+        readerFactory,
+        (type, value) -> value);
+  }
+
+  protected VectorizedReaderBuilder(
+      Schema expectedSchema,
+      MessageType parquetSchema,
+      boolean setArrowValidityVector,
+      Map<Integer, ?> idToConstant,
+      Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
+      BiFunction<Type, Object, Object> convert) {
     this.parquetSchema = parquetSchema;
     this.icebergSchema = expectedSchema;
     this.rootAllocator =
@@ -62,6 +80,7 @@ public VectorizedReaderBuilder(
     this.setArrowValidityVector = setArrowValidityVector;
     this.idToConstant = idToConstant;
     this.readerFactory = readerFactory;
+    this.convert = convert;
   }
 
   @Override
@@ -85,7 +104,7 @@ public VectorizedReader<?> message(
       int id = field.fieldId();
       VectorizedReader<?> reader = readersById.get(id);
       if (idToConstant.containsKey(id)) {
-        reorderedFields.add(new ConstantVectorReader<>(field, idToConstant.get(id)));
+        reorderedFields.add(constantReader(field, idToConstant.get(id)));
       } else if (id == MetadataColumns.ROW_POSITION.fieldId()) {
         if (setArrowValidityVector) {
           reorderedFields.add(VectorizedArrowReader.positionsWithSetArrowValidityVector());
@@ -96,13 +115,23 @@ public VectorizedReader<?> message(
         reorderedFields.add(new DeletedVectorReader());
       } else if (reader != null) {
         reorderedFields.add(reader);
-      } else {
+      } else if (field.initialDefault() != null) {
+        reorderedFields.add(
+            constantReader(field, convert.apply(field.type(), field.initialDefault())));
+      } else if (field.isOptional()) {
         reorderedFields.add(VectorizedArrowReader.nulls());
+      } else {
+        throw new IllegalArgumentException(
+            String.format("Missing required field: %s", field.name()));
       }
     }
     return vectorizedReader(reorderedFields);
   }
 
+  private <T> ConstantVectorReader<T> constantReader(Types.NestedField field, T constant) {
+    return new ConstantVectorReader<>(field, constant);
+  }
+
   protected VectorizedReader<?> vectorizedReader(List<VectorizedReader<?>> reorderedFields) {
     return readerFactory.apply(reorderedFields);
   }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java
index e47152c79398..636ad3be7dcc 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java
@@ -27,6 +27,7 @@
 import org.apache.iceberg.data.DeleteFilter;
 import org.apache.iceberg.parquet.TypeWithSchemaVisitor;
 import org.apache.iceberg.parquet.VectorizedReader;
+import org.apache.iceberg.spark.SparkUtil;
 import org.apache.parquet.schema.MessageType;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.slf4j.Logger;
@@ -112,7 +113,13 @@ private static class ReaderBuilder extends VectorizedReaderBuilder {
         Map<Integer, ?> idToConstant,
         Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
         DeleteFilter<InternalRow> deleteFilter) {
-      super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory);
+      super(
+          expectedSchema,
+          parquetSchema,
+          setArrowValidityVector,
+          idToConstant,
+          readerFactory,
+          SparkUtil::internalToSpark);
       this.deleteFilter = deleteFilter;
     }
 
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
index 4f7eab30a47d..d6e8ae773b4b 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java
@@ -63,6 +63,10 @@ protected boolean supportsDefaultValues() {
     return false;
   }
 
+  protected boolean supportsNestedTypes() {
+    return true;
+  }
+
   protected static final StructType SUPPORTED_PRIMITIVES =
       StructType.of(
           required(100, "id", LongType.get()),
@@ -74,6 +78,7 @@ protected boolean supportsDefaultValues() {
           required(106, "d", Types.DoubleType.get()),
           optional(107, "date", Types.DateType.get()),
           required(108, "ts", Types.TimestampType.withZone()),
+          required(109, "ts_without_zone", Types.TimestampType.withoutZone()),
           required(110, "s", Types.StringType.get()),
           required(111, "uuid", Types.UUIDType.get()),
           required(112, "fixed", Types.FixedType.ofLength(7)),
@@ -109,12 +114,16 @@ public void testStructWithOptionalFields() throws IOException {
 
   @Test
   public void testNestedStruct() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     writeAndValidate(
         TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES))));
   }
 
   @Test
   public void testArray() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         new Schema(
             required(0, "id", LongType.get()),
@@ -125,6 +134,8 @@
 
   @Test
   public void testArrayOfStructs() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         TypeUtil.assignIncreasingFreshIds(
             new Schema(
@@ -136,6 +147,8 @@
 
   @Test
   public void testMap() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         new Schema(
             required(0, "id", LongType.get()),
@@ -149,6 +162,8 @@
 
   @Test
   public void testNumericMapKey() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         new Schema(
             required(0, "id", LongType.get()),
@@ -160,6 +175,8 @@
 
   @Test
   public void testComplexMapKey() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         new Schema(
             required(0, "id", LongType.get()),
@@ -179,6 +196,8 @@
 
   @Test
   public void testMapOfStructs() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     Schema schema =
         TypeUtil.assignIncreasingFreshIds(
             new Schema(
@@ -193,6 +212,8 @@ public void testMapOfStructs() throws IOException {
 
   @Test
   public void testMixedTypes() throws IOException {
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
+
     StructType structType =
         StructType.of(
             required(0, "id", LongType.get()),
@@ -248,17 +269,6 @@ public void testMixedTypes() throws IOException {
     writeAndValidate(schema);
   }
 
-  @Test
-  public void testTimestampWithoutZone() throws IOException {
-    Schema schema =
-        TypeUtil.assignIncreasingFreshIds(
-            new Schema(
-                required(0, "id", LongType.get()),
-                optional(1, "ts_without_zone", Types.TimestampType.withoutZone())));
-
-    writeAndValidate(schema);
-  }
-
   @Test
   public void testMissingRequiredWithoutDefault() {
     Assumptions.assumeThat(supportsDefaultValues()).isTrue();
@@ -348,6 +358,7 @@ public void testNullDefaultValue() throws IOException {
   @Test
   public void testNestedDefaultValue() throws IOException {
     Assumptions.assumeThat(supportsDefaultValues()).isTrue();
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
 
     Schema writeSchema =
         new Schema(
@@ -391,6 +402,7 @@
   @Test
   public void testMapNestedDefaultValue() throws IOException {
     Assumptions.assumeThat(supportsDefaultValues()).isTrue();
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
 
     Schema writeSchema =
         new Schema(
@@ -443,6 +455,7 @@
   @Test
   public void testListNestedDefaultValue() throws IOException {
     Assumptions.assumeThat(supportsDefaultValues()).isTrue();
+    Assumptions.assumeThat(supportsNestedTypes()).isTrue();
 
     Schema writeSchema =
         new Schema(
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
index d3d69e4b9d86..64d0b85625a9 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
@@ -79,6 +79,8 @@
 import org.apache.spark.sql.types.MapType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampNTZType;
+import org.apache.spark.sql.types.TimestampType$;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.apache.spark.unsafe.types.UTF8String;
 import scala.collection.Seq;
@@ -107,13 +109,25 @@ public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row
   public static void assertEqualsBatch(
       Types.StructType struct, Iterator<Record> expected, ColumnarBatch batch) {
     for (int rowId = 0; rowId < batch.numRows(); rowId++) {
-      List<Types.NestedField> fields = struct.fields();
       InternalRow row = batch.getRow(rowId);
       Record rec = expected.next();
-      for (int i = 0; i < fields.size(); i += 1) {
-        Type fieldType = fields.get(i).type();
-        Object expectedValue = rec.get(i);
-        Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
+
+      List<Types.NestedField> fields = struct.fields();
+      for (int readPos = 0; readPos < fields.size(); readPos += 1) {
+        Types.NestedField field = fields.get(readPos);
+        Field writeField = rec.getSchema().getField(field.name());
+
+        Type fieldType = field.type();
+        Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));
+
+        Object expectedValue;
+        if (writeField != null) {
+          int writePos = writeField.pos();
+          expectedValue = rec.get(writePos);
+        } else {
+          expectedValue = field.initialDefault();
+        }
+
         assertEqualsUnsafe(fieldType, expectedValue, actualValue);
       }
     }
@@ -751,6 +765,12 @@ private static void assertEquals(
     for (int i = 0; i < actual.numFields(); i += 1) {
       StructField field = struct.fields()[i];
       DataType type = field.dataType();
+      // ColumnarRow.get doesn't support TimestampNTZType, causing tests to fail. the representation
+      // is identical to TimestampType so this uses that type to validate.
+      if (type instanceof TimestampNTZType) {
+        type = TimestampType$.MODULE$;
+      }
+
       assertEquals(
           context + "." + field.name(),
           type,
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java
index 5c4b216aff94..4f7864e9a160 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java
@@ -49,7 +49,6 @@
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.Type;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
 public class TestParquetVectorizedReads extends AvroDataTest {
@@ -60,18 +59,42 @@ public class TestParquetVectorizedReads extends AvroDataTest {
 
   @Override
   protected void writeAndValidate(Schema schema) throws IOException {
-    writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true);
+    writeAndValidate(schema, schema);
+  }
+
+  @Override
+  protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException {
+    writeAndValidate(
+        writeSchema,
+        expectedSchema,
+        getNumRows(),
+        29714278L,
+        RandomData.DEFAULT_NULL_PERCENTAGE,
+        true,
+        BATCH_SIZE,
+        IDENTITY);
+  }
+
+  @Override
+  protected boolean supportsDefaultValues() {
+    return true;
+  }
+
+  @Override
+  protected boolean supportsNestedTypes() {
+    return false;
   }
 
   private void writeAndValidate(
       Schema schema, int numRecords, long seed, float nullPercentage, boolean reuseContainers)
       throws IOException {
     writeAndValidate(
-        schema, numRecords, seed, nullPercentage, reuseContainers, BATCH_SIZE, IDENTITY);
+        schema, schema, numRecords, seed, nullPercentage, reuseContainers, BATCH_SIZE, IDENTITY);
   }
 
   private void writeAndValidate(
-      Schema schema,
+      Schema writeSchema,
+      Schema expectedSchema,
       int numRecords,
       long seed,
       float nullPercentage,
@@ -82,22 +105,23 @@ private void writeAndValidate(
     // Write test data
     assumeThat(
             TypeUtil.find(
-                schema,
+                writeSchema,
                 type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get()))
         .as("Parquet Avro cannot write non-string map keys")
         .isNull();
 
     Iterable<GenericData.Record> expected =
-        generateData(schema, numRecords, seed, nullPercentage, transform);
+        generateData(writeSchema, numRecords, seed, nullPercentage, transform);
 
     // write a test parquet file using iceberg writer
     File testFile = File.createTempFile("junit", null, temp.toFile());
     assertThat(testFile.delete()).as("Delete should succeed").isTrue();
 
-    try (FileAppender<GenericData.Record> writer = getParquetWriter(schema, testFile)) {
+    try (FileAppender<GenericData.Record> writer = getParquetWriter(writeSchema, testFile)) {
       writer.addAll(expected);
     }
-    assertRecordsMatch(schema, numRecords, expected, testFile, reuseContainers, batchSize);
+
+    assertRecordsMatch(expectedSchema, numRecords, expected, testFile, reuseContainers, batchSize);
   }
 
   protected int getNumRows() {
@@ -161,41 +185,6 @@ void assertRecordsMatch(
     }
   }
 
-  @Override
-  @Test
-  @Disabled
-  public void testArray() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testArrayOfStructs() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testMap() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testNumericMapKey() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testComplexMapKey() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testMapOfStructs() {}
-
-  @Override
-  @Test
-  @Disabled
-  public void testMixedTypes() {}
-
   @Test
   @Override
   public void testNestedStruct() {
@@ -246,10 +235,13 @@ public void testVectorizedReadsWithNewContainers() throws IOException {
   public void testVectorizedReadsWithReallocatedArrowBuffers() throws IOException {
     // With a batch size of 2, 256 bytes are allocated in the VarCharVector. By adding strings of
    // length 512, the vector will need to be reallocated for storing the batch.
-    writeAndValidate(
+    Schema schema =
         new Schema(
             Lists.newArrayList(
-                SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data"))),
+                SUPPORTED_PRIMITIVES.field("id"), SUPPORTED_PRIMITIVES.field("data")));
+    writeAndValidate(
+        schema,
+        schema,
         10,
         0L,
         RandomData.DEFAULT_NULL_PERCENTAGE,
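
Usage note (not part of the patch): the new protected VectorizedReaderBuilder constructor lets an engine-specific subclass decide how Iceberg default values are converted before they are wrapped in a constant vector reader, which is what the Spark ReaderBuilder above does by passing SparkUtil::internalToSpark. Below is a minimal sketch of such a subclass. The EngineReaderBuilder name and the identity converter are hypothetical illustrations, and the lambda assumes the convert hook receives the field's Iceberg type plus the raw default value, as in the patch.

    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.arrow.vectorized.VectorizedReaderBuilder;
    import org.apache.iceberg.parquet.VectorizedReader;
    import org.apache.parquet.schema.MessageType;

    class EngineReaderBuilder extends VectorizedReaderBuilder {
      EngineReaderBuilder(
          Schema expectedSchema,
          MessageType parquetSchema,
          boolean setArrowValidityVector,
          Map<Integer, ?> idToConstant,
          Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory) {
        super(
            expectedSchema,
            parquetSchema,
            setArrowValidityVector,
            idToConstant,
            readerFactory,
            // hypothetical converter: map Iceberg's internal default-value representation to the
            // engine's runtime representation; Spark plugs in SparkUtil::internalToSpark here
            (type, value) -> value);
      }
    }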