diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index e6fbccb8966d..d09cb712ea25 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -223,6 +223,12 @@ impl<W: Write> ArrowWriter<W> {
         Ok(())
     }
 
+    /// Flushes any outstanding data and returns the underlying writer.
+    pub fn into_inner(mut self) -> Result<W> {
+        self.flush()?;
+        self.writer.into_inner()
+    }
+
     /// Close and finalize the underlying Parquet writer
     pub fn close(mut self) -> Result<parquet_format::FileMetaData> {
         self.flush()?;
@@ -644,6 +650,25 @@ mod tests {
         roundtrip(batch, Some(SMALL_SIZE / 2));
     }
 
+    fn get_bytes_after_close(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec<u8> {
+        let mut buffer = vec![];
+
+        let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
+        writer.write(expected_batch).unwrap();
+        writer.close().unwrap();
+
+        buffer
+    }
+
+    fn get_bytes_by_into_inner(
+        schema: SchemaRef,
+        expected_batch: &RecordBatch,
+    ) -> Vec<u8> {
+        let mut writer = ArrowWriter::try_new(Vec::new(), schema, None).unwrap();
+        writer.write(expected_batch).unwrap();
+        writer.into_inner().unwrap()
+    }
+
     #[test]
     fn roundtrip_bytes() {
         // define schema
@@ -660,31 +685,28 @@ mod tests {
         let expected_batch =
             RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();
 
-        let mut buffer = vec![];
-
-        {
-            let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
-            writer.write(&expected_batch).unwrap();
-            writer.close().unwrap();
-        }
-
-        let cursor = Bytes::from(buffer);
-        let mut record_batch_reader =
-            ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();
-
-        let actual_batch = record_batch_reader
-            .next()
-            .expect("No batch found")
-            .expect("Unable to get batch");
-
-        assert_eq!(expected_batch.schema(), actual_batch.schema());
-        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
-        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
-        for i in 0..expected_batch.num_columns() {
-            let expected_data = expected_batch.column(i).data().clone();
-            let actual_data = actual_batch.column(i).data().clone();
-
-            assert_eq!(expected_data, actual_data);
+        for buffer in vec![
+            get_bytes_after_close(schema.clone(), &expected_batch),
+            get_bytes_by_into_inner(schema, &expected_batch),
+        ] {
+            let cursor = Bytes::from(buffer);
+            let mut record_batch_reader =
+                ParquetRecordBatchReader::try_new(cursor, 1024).unwrap();
+
+            let actual_batch = record_batch_reader
+                .next()
+                .expect("No batch found")
+                .expect("Unable to get batch");
+
+            assert_eq!(expected_batch.schema(), actual_batch.schema());
+            assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
+            assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
+            for i in 0..expected_batch.num_columns() {
+                let expected_data = expected_batch.column(i).data().clone();
+                let actual_data = actual_batch.column(i).data().clone();
+
+                assert_eq!(expected_data, actual_data);
+            }
         }
     }
 
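Note: `ArrowWriter::into_inner` is most useful when the writer owns its output buffer, since `close()` consumes the writer and returns only the file metadata. A minimal caller-side sketch (the schema, the column name `c0`, the values, and the `main` scaffolding below are illustrative, not part of this diff):

    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Illustrative schema and batch; any Arrow schema works the same way.
        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;

        // Handing the writer an owned Vec<u8> (rather than &mut Vec<u8>)
        // is what makes the buffer recoverable afterwards.
        let mut writer = ArrowWriter::try_new(Vec::new(), schema, None)?;
        writer.write(&batch)?;

        // Flushes buffered row groups, writes the footer, and returns the Vec<u8>.
        let bytes = writer.into_inner()?;
        assert!(!bytes.is_empty());
        Ok(())
    }

This mirrors `get_bytes_by_into_inner` in the test above, which no longer needs the scoped `&mut buffer` borrow that `get_bytes_after_close` uses.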
diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs
index b7bab189bb83..7af4b0fa2c94 100644
--- a/parquet/src/file/writer.rs
+++ b/parquet/src/file/writer.rs
@@ -62,6 +62,11 @@ impl<W: Write> TrackedWrite<W> {
     pub fn bytes_written(&self) -> usize {
         self.bytes_written
     }
+
+    /// Returns the underlying writer.
+    pub fn into_inner(self) -> W {
+        self.inner
+    }
 }
 
 impl<W: Write> Write for TrackedWrite<W> {
@@ -292,6 +297,14 @@ impl<W: Write> SerializedFileWriter<W> {
             Ok(())
         }
     }
+
+    /// Writes the file footer and returns the underlying writer.
+    pub fn into_inner(mut self) -> Result<W> {
+        self.assert_previous_writer_closed()?;
+        let _ = self.write_metadata()?;
+
+        Ok(self.buf.into_inner())
+    }
 }
 
 /// Parquet row group writer API.
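At the file level, `SerializedFileWriter::into_inner` parallels `close()`: both write the footer (via `write_metadata` above), but `close()` returns the file metadata while `into_inner` returns the writer. A minimal sketch against the low-level API, assuming the illustrative `message demo { ... }` schema below (not part of this diff):

    use std::sync::Arc;

    use parquet::file::properties::WriterProperties;
    use parquet::file::writer::SerializedFileWriter;
    use parquet::schema::parser::parse_message_type;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Illustrative single-column schema.
        let schema = Arc::new(parse_message_type("message demo { REQUIRED INT32 id; }")?);
        let props = Arc::new(WriterProperties::builder().build());

        let mut writer = SerializedFileWriter::new(Vec::new(), schema, props)?;

        // Row groups are written as usual before handing the buffer back.
        let mut row_group = writer.next_row_group()?;
        while let Some(column) = row_group.next_column()? {
            // ... values would be written here via the typed column writer ...
            column.close()?;
        }
        row_group.close()?;

        // Writes the footer and returns the owned Vec<u8>.
        let bytes = writer.into_inner()?;
        assert!(!bytes.is_empty());
        Ok(())
    }

Note that `let _ = self.write_metadata()?` discards the metadata that `close()` would return, so callers needing both the metadata and the writer would still reach for `close()`.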