Ensure Parquet schema metadata is added to arrow table #137

Merged · 9 commits · Aug 15, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
*.parquet
*.whl

# Generated by Cargo
4 changes: 3 additions & 1 deletion arro3-io/python/arro3/io/_io.pyi
@@ -255,7 +255,7 @@ def read_parquet(file: Path | str) -> core.RecordBatchReader:

def write_parquet(
data: types.ArrowStreamExportable | types.ArrowArrayExportable,
file: IO[bytes] | Path | str,
file: str,
*,
bloom_filter_enabled: bool | None = None,
bloom_filter_fpp: float | None = None,
@@ -274,6 +274,7 @@ def write_parquet(
key_value_metadata: dict[str, str] | None = None,
max_row_group_size: int | None = None,
max_statistics_size: int | None = None,
skip_arrow_metadata: bool = False,
write_batch_size: int | None = None,
writer_version: Literal["parquet_1_0", "parquet_2_0"] | None = None,
) -> None:
@@ -338,6 +339,7 @@ def write_parquet(
key_value_metadata: Sets "key_value_metadata" property (defaults to `None`).
max_row_group_size: Sets maximum number of rows in a row group (defaults to `1024 * 1024`).
max_statistics_size: Sets default max statistics size for all columns (defaults to `4096`).
skip_arrow_metadata: Parquet files generated by this writer contain an embedded Arrow schema by default. Set `skip_arrow_metadata` to `True` to skip encoding the embedded schema (defaults to `False`).
write_batch_size:
Sets write batch size (defaults to 1024).

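For context on the new `skip_arrow_metadata` option, a minimal usage sketch against the stub above (the file names and metadata values here are illustrative, not part of the change):

```python
import pyarrow as pa
from arro3.io import write_parquet

table = pa.table({"a": [1, 2, 3]})

# Default behaviour: the Arrow schema is embedded in the Parquet
# key-value metadata under the "ARROW:schema" key.
write_parquet(table, "with_schema.parquet")

# Opt out of the embedded Arrow schema while still writing custom
# key-value metadata.
write_parquet(
    table,
    "without_schema.parquet",
    key_value_metadata={"hello": "world"},
    skip_arrow_metadata=True,
)
```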
31 changes: 27 additions & 4 deletions arro3-io/src/parquet.rs
@@ -1,6 +1,9 @@
use std::collections::HashMap;
use std::fs::File;
use std::str::FromStr;
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::arrow_writer::ArrowWriterOptions;
use parquet::arrow::ArrowWriter;
@@ -14,16 +17,30 @@ use pyo3_arrow::error::PyArrowResult;
use pyo3_arrow::input::AnyRecordBatch;
use pyo3_arrow::PyRecordBatchReader;

use crate::utils::{FileReader, FileWriter};
use crate::utils::FileReader;

#[pyfunction]
pub fn read_parquet(py: Python, file: FileReader) -> PyArrowResult<PyObject> {
match file {
FileReader::File(f) => {
let builder = ParquetRecordBatchReaderBuilder::try_new(f).unwrap();

let metadata = builder.schema().metadata().clone();
let reader = builder.build().unwrap();
Ok(PyRecordBatchReader::new(Box::new(reader)).to_arro3(py)?)

// Add the source schema's metadata onto the reader's schema. The original schema
// itself is not valid under a column projection, but we still want to persist the
// source's metadata.
let arrow_schema = Arc::new(reader.schema().as_ref().clone().with_metadata(metadata));

// Create a new iterator with the arrow schema specifically
//
// Passing ParquetRecordBatchReader directly to PyRecordBatchReader::new loses schema
// metadata
//
// https://docs.rs/parquet/latest/parquet/arrow/arrow_reader/struct.ParquetRecordBatchReader.html#method.schema
// https://github.com/apache/arrow-rs/pull/5135
let iter = Box::new(RecordBatchIterator::new(reader, arrow_schema));
Ok(PyRecordBatchReader::new(iter).to_arro3(py)?)
}
FileReader::FileLike(_) => {
Err(PyTypeError::new_err("File objects not yet supported for reading parquet").into())
@@ -105,13 +122,14 @@ impl<'py> FromPyObject<'py> for PyColumnPath {
key_value_metadata = None,
max_row_group_size = None,
max_statistics_size = None,
skip_arrow_metadata = false,
write_batch_size = None,
writer_version = None,
))]
#[allow(clippy::too_many_arguments)]
pub(crate) fn write_parquet(
data: AnyRecordBatch,
file: FileWriter,
file: String,
bloom_filter_enabled: Option<bool>,
bloom_filter_fpp: Option<f64>,
bloom_filter_ndv: Option<u64>,
@@ -129,9 +147,12 @@ pub(crate) fn write_parquet(
key_value_metadata: Option<HashMap<String, String>>,
max_row_group_size: Option<usize>,
max_statistics_size: Option<usize>,
skip_arrow_metadata: bool,
write_batch_size: Option<usize>,
writer_version: Option<PyWriterVersion>,
) -> PyArrowResult<()> {
let file = File::create(file).map_err(|err| PyValueError::new_err(err.to_string()))?;

let mut props = WriterProperties::builder();

if let Some(writer_version) = writer_version {
@@ -207,7 +228,9 @@ pub(crate) fn write_parquet(

let reader = data.into_reader()?;

let writer_options = ArrowWriterOptions::new().with_properties(props.build());
let writer_options = ArrowWriterOptions::new()
.with_properties(props.build())
.with_skip_arrow_metadata(skip_arrow_metadata);
let mut writer =
ArrowWriter::try_new_with_options(file, reader.schema(), writer_options).unwrap();
for batch in reader {
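The effect of wrapping the Parquet reader in a `RecordBatchIterator` with the source schema is that key-value metadata survives the round trip to Python. A minimal sketch, assuming the same `write_parquet`/`read_parquet` API exercised in the tests below (file name illustrative):

```python
import pyarrow as pa
from arro3.io import read_parquet, write_parquet

write_parquet(
    pa.table({"a": [1, 2, 3]}),
    "meta.parquet",
    key_value_metadata={"hello": "world"},
)

# The reader's schema now carries the file's key-value metadata instead
# of dropping it when the stream is exported to Python.
reader = read_parquet("meta.parquet")
assert reader.schema.metadata[b"hello"] == b"world"
```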
4 changes: 3 additions & 1 deletion pyo3-arrow/src/record_batch_reader.rs
@@ -139,7 +139,9 @@ impl PyRecordBatchReader {
});
let array_reader = Box::new(ArrayIterator::new(
array_reader,
Field::new_struct("", schema.fields().clone(), false).into(),
Field::new_struct("", schema.fields().clone(), false)
.with_metadata(schema.metadata.clone())
.into(),
));
to_stream_pycapsule(py, array_reader, requested_schema)
}
4 changes: 3 additions & 1 deletion pyo3-arrow/src/table.rs
@@ -209,7 +209,9 @@ impl PyTable {
});
let array_reader = Box::new(ArrayIterator::new(
array_reader,
Field::new_struct("", field, false).into(),
Field::new_struct("", field, false)
.with_metadata(self.schema.metadata.clone())
.into(),
));
to_stream_pycapsule(py, array_reader, requested_schema)
}
25 changes: 25 additions & 0 deletions tests/core/test_ffi.py
@@ -47,3 +47,28 @@ def test_array_export_schema_request():

retour = Array.from_arrow_pycapsule(*capsules)
assert retour.type == DataType.large_utf8()


def test_table_metadata_preserved():
metadata = {b"hello": b"world"}
pa_table = pa.table({"a": [1, 2, 3]})
pa_table = pa_table.replace_schema_metadata(metadata)

arro3_table = Table(pa_table)
assert arro3_table.schema.metadata == metadata

pa_table_retour = pa.table(arro3_table)
assert pa_table_retour.schema.metadata == metadata


def test_record_batch_reader_metadata_preserved():
metadata = {b"hello": b"world"}
pa_table = pa.table({"a": [1, 2, 3]})
pa_table = pa_table.replace_schema_metadata(metadata)
pa_reader = pa.RecordBatchReader.from_stream(pa_table)

arro3_reader = RecordBatchReader.from_stream(pa_reader)
assert arro3_reader.schema.metadata == metadata

pa_reader_retour = pa.RecordBatchReader.from_stream(arro3_reader)
assert pa_reader_retour.schema.metadata == metadata
26 changes: 26 additions & 0 deletions tests/io/test_parquet.py
@@ -0,0 +1,26 @@
import pyarrow as pa
import pyarrow.parquet as pq
from arro3.io import read_parquet, write_parquet


def test_copy_parquet_kv_metadata():
metadata = {"hello": "world"}
table = pa.table({"a": [1, 2, 3]})
write_parquet(
table,
"test.parquet",
key_value_metadata=metadata,
skip_arrow_metadata=True,
)

# Assert metadata was written, but arrow schema was not
pq_meta = pq.read_metadata("test.parquet").metadata
assert pq_meta[b"hello"] == b"world"
assert b"ARROW:schema" not in pq_meta.keys()

# When reading with pyarrow, kv meta gets assigned to table
pa_table = pq.read_table("test.parquet")
assert pa_table.schema.metadata[b"hello"] == b"world"

reader = read_parquet("test.parquet")
assert reader.schema.metadata[b"hello"] == b"world"