Commit df01d88
Parquet: Expose variantShreddingFunc() in Parquet.DataWriteBuilder (#14153)
1 parent fbb7136 commit df01d88

2 files changed: +78 -8 lines changed

parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java

Lines changed: 5 additions & 0 deletions
@@ -827,6 +827,11 @@ public DataWriteBuilder createWriterFunc(
       return this;
     }
 
+    public DataWriteBuilder variantShreddingFunc(VariantShreddingFunction func) {
+      appenderBuilder.variantShreddingFunc(func);
+      return this;
+    }
+
     public DataWriteBuilder withSpec(PartitionSpec newSpec) {
       this.spec = newSpec;
       return this;
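
Usage sketch (not part of the commit itself): the newly exposed builder hook can be wired the same way the test below does. Here, outputFile and schema are stand-ins, and shreddedTypeFor(name) is a hypothetical helper returning the shredded Parquet type for a variant column, or null to leave the column unshredded.

// Minimal sketch, assuming outputFile and schema exist and that
// shreddedTypeFor(name) is a hypothetical helper producing the shredded
// Parquet type for the named variant column (null means no shredding).
DataWriter<Record> writer =
    Parquet.writeData(outputFile)
        .schema(schema)
        .createWriterFunc(GenericParquetWriter::create)
        .variantShreddingFunc((fieldId, name) -> shreddedTypeFor(name))
        .withSpec(PartitionSpec.unpartitioned())
        .build();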

parquet/src/test/java/org/apache/iceberg/parquet/TestParquetDataWriter.java

Lines changed: 73 additions & 8 deletions
@@ -25,10 +25,12 @@
 import java.nio.ByteBuffer;
 import java.nio.file.Path;
 import java.util.List;
+import java.util.Optional;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileContent;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.Files;
+import org.apache.iceberg.InternalTestHelpers;
 import org.apache.iceberg.MetricsConfig;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
@@ -47,6 +49,13 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.variants.Variant;
+import org.apache.iceberg.variants.VariantMetadata;
+import org.apache.iceberg.variants.VariantTestUtil;
+import org.apache.iceberg.variants.Variants;
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.MessageType;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
@@ -78,25 +87,29 @@ public void createRecords() {
 
   @Test
   public void testDataWriter() throws IOException {
+    testDataWriter(SCHEMA, (id, name) -> null);
+  }
+
+  private void testDataWriter(Schema schema, VariantShreddingFunction variantShreddingFunc)
+      throws IOException {
     OutputFile file = Files.localOutput(createTempFile(temp));
 
-    SortOrder sortOrder = SortOrder.builderFor(SCHEMA).withOrderId(10).asc("id").build();
+    SortOrder sortOrder = SortOrder.builderFor(schema).withOrderId(10).asc("id").build();
 
     DataWriter<Record> dataWriter =
         Parquet.writeData(file)
-            .schema(SCHEMA)
+            .schema(schema)
             .createWriterFunc(GenericParquetWriter::create)
+            .variantShreddingFunc(variantShreddingFunc)
             .overwrite()
             .withSpec(PartitionSpec.unpartitioned())
            .withSortOrder(sortOrder)
            .build();
 
-    try {
+    try (dataWriter) {
       for (Record record : records) {
         dataWriter.write(record);
       }
-    } finally {
-      dataWriter.close();
     }
 
     DataFile dataFile = dataWriter.toDataFile();
@@ -113,13 +126,32 @@ public void testDataWriter() throws IOException {
     List<Record> writtenRecords;
     try (CloseableIterable<Record> reader =
         Parquet.read(file.toInputFile())
-            .project(SCHEMA)
-            .createReaderFunc(fileSchema -> GenericParquetReaders.buildReader(SCHEMA, fileSchema))
+            .project(schema)
+            .createReaderFunc(fileSchema -> GenericParquetReaders.buildReader(schema, fileSchema))
             .build()) {
       writtenRecords = Lists.newArrayList(reader);
     }
 
-    assertThat(writtenRecords).as("Written records should match").isEqualTo(records);
+    assertThat(writtenRecords).hasSameSizeAs(records);
+
+    for (int i = 0; i < records.size(); i++) {
+      InternalTestHelpers.assertEquals(schema.asStruct(), records.get(i), writtenRecords.get(i));
+    }
+
+    // Check physical Parquet schema if variant shredding function is provided
+    Optional<Types.NestedField> variantField =
+        schema.columns().stream()
+            .filter(field -> field.type().equals(Types.VariantType.get()))
+            .findFirst();
+
+    if (variantField.isPresent() && variantShreddingFunc != null) {
+      try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file.toInputFile()))) {
+        MessageType parquetSchema = reader.getFooter().getFileMetaData().getSchema();
+        GroupType variantType = parquetSchema.getType(variantField.get().name()).asGroupType();
+
+        assertThat(variantType.containsField("typed_value")).isTrue();
+      }
+    }
   }
 
   @SuppressWarnings("checkstyle:AvoidEscapedUnicodeCharacters")
@@ -266,4 +298,37 @@ public void testInvalidUpperBoundBinary() throws Exception {
     assertThat(dataFile.lowerBounds()).as("Should have a valid lower bound").containsKey(3);
     assertThat(dataFile.upperBounds()).as("Should have a null upper bound").doesNotContainKey(3);
   }
+
+  @Test
+  public void testDataWriterWithVariantShredding() throws IOException {
+    Schema variantSchema =
+        new Schema(
+            ImmutableList.<Types.NestedField>builder()
+                .addAll(SCHEMA.columns())
+                .add(Types.NestedField.optional(4, "variant", Types.VariantType.get()))
+                .build());
+
+    ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a", "b"), true);
+    VariantMetadata metadata = Variants.metadata(metadataBuffer);
+
+    ByteBuffer objectBuffer =
+        VariantTestUtil.createObject(
+            metadataBuffer,
+            ImmutableMap.of(
+                "a", Variants.of(123456789),
+                "b", Variants.of("string")));
+
+    Variant variant = Variant.of(metadata, Variants.value(metadata, objectBuffer));
+
+    // Create records with variant data
+    GenericRecord record = GenericRecord.create(variantSchema);
+
+    records =
+        ImmutableList.of(
+            record.copy(ImmutableMap.of("id", 1L, "variant", variant)),
+            record.copy(ImmutableMap.of("id", 2L, "variant", variant)));
+
+    testDataWriter(
+        variantSchema, (id, name) -> ParquetVariantUtil.toParquetSchema(variant.value()));
+  }
 }
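
For readers implementing their own VariantShreddingFunction, a minimal sketch in the shape the new test uses: given a field id and name, it returns the shredded Parquet type for that variant column, or null to skip shredding. Here, sampleValue is an assumed representative VariantValue for the column's data.

// Minimal sketch, assuming sampleValue is a representative
// org.apache.iceberg.variants.VariantValue for the column's data;
// ParquetVariantUtil.toParquetSchema(...) derives the shredded Parquet
// type from it, as testDataWriterWithVariantShredding() does above.
// Returning null instead disables shredding, as in the base testDataWriter().
VariantShreddingFunction shreddingFunc =
    (fieldId, name) -> ParquetVariantUtil.toParquetSchema(sampleValue);

When such a function is supplied, the written file's physical schema for the variant column gains a typed_value group, which is what the test asserts via containsField("typed_value").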
