 import java.nio.ByteBuffer;
 import java.nio.file.Path;
 import java.util.List;
+import java.util.Optional;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.FileContent;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.Files;
+import org.apache.iceberg.InternalTestHelpers;
 import org.apache.iceberg.MetricsConfig;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
@@ -47,6 +49,13 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.variants.Variant;
+import org.apache.iceberg.variants.VariantMetadata;
+import org.apache.iceberg.variants.VariantTestUtil;
+import org.apache.iceberg.variants.Variants;
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.MessageType;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
@@ -78,25 +87,31 @@ public void createRecords() {
 
   @Test
   public void testDataWriter() throws IOException {
+    testDataWriter(SCHEMA, (id, name) -> null);
+  }
+
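+  // Shared write path: round-trips the records through a Parquet DataWriter and, when a
+  // variant shredding function is supplied, also checks the physical file schema.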
+  private void testDataWriter(Schema schema, VariantShreddingFunction variantShreddingFunc)
+      throws IOException {
     OutputFile file = Files.localOutput(createTempFile(temp));
 
-    SortOrder sortOrder = SortOrder.builderFor(SCHEMA).withOrderId(10).asc("id").build();
+    SortOrder sortOrder = SortOrder.builderFor(schema).withOrderId(10).asc("id").build();
 
     DataWriter<Record> dataWriter =
         Parquet.writeData(file)
-            .schema(SCHEMA)
+            .schema(schema)
             .createWriterFunc(GenericParquetWriter::create)
+            .variantShreddingFunc(variantShreddingFunc)
             .overwrite()
             .withSpec(PartitionSpec.unpartitioned())
             .withSortOrder(sortOrder)
             .build();
 
-    try {
+    try (dataWriter) {
       for (Record record : records) {
         dataWriter.write(record);
       }
-    } finally {
-      dataWriter.close();
     }
 
     DataFile dataFile = dataWriter.toDataFile();
@@ -113,13 +128,33 @@ public void testDataWriter() throws IOException {
     List<Record> writtenRecords;
     try (CloseableIterable<Record> reader =
         Parquet.read(file.toInputFile())
-            .project(SCHEMA)
-            .createReaderFunc(fileSchema -> GenericParquetReaders.buildReader(SCHEMA, fileSchema))
+            .project(schema)
+            .createReaderFunc(fileSchema -> GenericParquetReaders.buildReader(schema, fileSchema))
             .build()) {
       writtenRecords = Lists.newArrayList(reader);
     }
 
-    assertThat(writtenRecords).as("Written records should match").isEqualTo(records);
+    assertThat(writtenRecords).hasSameSizeAs(records);
+
+    for (int i = 0; i < records.size(); i++) {
+      InternalTestHelpers.assertEquals(schema.asStruct(), records.get(i), writtenRecords.get(i));
+    }
+
+    // Check physical Parquet schema if variant shredding function is provided
+    Optional<Types.NestedField> variantField =
+        schema.columns().stream()
+            .filter(field -> field.type().equals(Types.VariantType.get()))
+            .findFirst();
+
+    if (variantField.isPresent() && variantShreddingFunc != null) {
+      try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(file.toInputFile()))) {
+        MessageType parquetSchema = reader.getFooter().getFileMetaData().getSchema();
+        GroupType variantType = parquetSchema.getType(variantField.get().name()).asGroupType();
+
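+        // Shredding should materialize a typed_value field inside the variant group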
+        assertThat(variantType.containsField("typed_value")).isTrue();
+      }
+    }
   }
 
   @SuppressWarnings("checkstyle:AvoidEscapedUnicodeCharacters")
@@ -266,4 +301,39 @@ public void testInvalidUpperBoundBinary() throws Exception {
     assertThat(dataFile.lowerBounds()).as("Should have a valid lower bound").containsKey(3);
     assertThat(dataFile.upperBounds()).as("Should have a null upper bound").doesNotContainKey(3);
   }
+
+  @Test
+  public void testDataWriterWithVariantShredding() throws IOException {
+    Schema variantSchema =
+        new Schema(
+            ImmutableList.<Types.NestedField>builder()
+                .addAll(SCHEMA.columns())
+                .add(Types.NestedField.optional(4, "variant", Types.VariantType.get()))
+                .build());
+
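+    // Build a sample variant object {"a": 123456789, "b": "string"}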
+    ByteBuffer metadataBuffer = VariantTestUtil.createMetadata(ImmutableList.of("a", "b"), true);
+    VariantMetadata metadata = Variants.metadata(metadataBuffer);
+
+    ByteBuffer objectBuffer =
+        VariantTestUtil.createObject(
+            metadataBuffer,
+            ImmutableMap.of(
+                "a", Variants.of(123456789),
+                "b", Variants.of("string")));
+
+    Variant variant = Variant.of(metadata, Variants.value(metadata, objectBuffer));
+
+    // Create records with variant data
+    GenericRecord record = GenericRecord.create(variantSchema);
+
+    records =
+        ImmutableList.of(
+            record.copy(ImmutableMap.of("id", 1L, "variant", variant)),
+            record.copy(ImmutableMap.of("id", 2L, "variant", variant)));
+
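+    // Shred the variant column using a Parquet schema derived from the sample value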
+    testDataWriter(
+        variantSchema, (id, name) -> ParquetVariantUtil.toParquetSchema(variant.value()));
+  }
 }