diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 4cf54dc8897e..d026f971e946 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -132,7 +132,10 @@ impl ArrowWriter { /// and drop any fully written `RecordBatch` pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { // validate batch schema against writer's supplied schema - if self.arrow_schema != batch.schema() { + let batch_schema = batch.schema(); + if !(Arc::ptr_eq(&self.arrow_schema, &batch_schema) + || self.arrow_schema.contains(&batch_schema)) + { return Err(ParquetError::ArrowError( "Record batch schema does not match writer schema".to_string(), )); @@ -2358,4 +2361,51 @@ mod tests { let actual = pretty_format_batches(&batches).unwrap().to_string(); assert_eq!(actual, expected); } + + #[test] + fn test_arrow_writer_metadata() { + let batch_schema = Schema::new(vec![Field::new("int32", DataType::Int32, false)]); + let file_schema = batch_schema.clone().with_metadata( + vec![("foo".to_string(), "bar".to_string())] + .into_iter() + .collect(), + ); + + let batch = RecordBatch::try_new( + Arc::new(batch_schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _], + ) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buf, Arc::new(file_schema), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[test] + fn test_arrow_writer_nullable() { + let batch_schema = Schema::new(vec![Field::new("int32", DataType::Int32, false)]); + let file_schema = Schema::new(vec![Field::new("int32", DataType::Int32, true)]); + let file_schema = Arc::new(file_schema); + + let batch = RecordBatch::try_new( + Arc::new(batch_schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _], + ) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buf, file_schema.clone(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let mut read = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024).unwrap(); + let back = read.next().unwrap().unwrap(); + assert_eq!(back.schema(), file_schema); + assert_ne!(back.schema(), batch.schema()); + assert_eq!(back.column(0).as_ref(), batch.column(0).as_ref()); + } }