diff --git a/Cargo.lock b/Cargo.lock
index b2d0f6e..2679981 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -112,10 +112,11 @@ dependencies = [
  "parquet",
  "pyo3",
  "pyo3-arrow",
- "pyo3-async-runtimes",
+ "pyo3-async-runtimes 0.21.0",
  "pyo3-file",
  "pyo3-object_store",
  "thiserror",
+ "tokio",
 ]
 
 [[package]]
@@ -1560,6 +1561,19 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "pyo3-async-runtimes"
+version = "0.22.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2529f0be73ffd2be0cc43c013a640796558aa12d7ca0aab5cc14f375b4733031"
+dependencies = [
+ "futures",
+ "once_cell",
+ "pin-project-lite",
+ "pyo3",
+ "tokio",
+]
+
 [[package]]
 name = "pyo3-build-config"
 version = "0.22.3"
@@ -1616,14 +1630,15 @@ dependencies = [
 
 [[package]]
 name = "pyo3-object_store"
-version = "0.1.0"
-source = "git+https://github.com/developmentseed/object-store-rs?rev=922b58ff784271345ce80342cf4cd6cddce61adf#922b58ff784271345ce80342cf4cd6cddce61adf"
+version = "0.1.0-beta.1"
+source = "git+https://github.com/developmentseed/obstore?rev=f0dad90f1e5e157760335d1ccb4045e1f3b4f194#f0dad90f1e5e157760335d1ccb4045e1f3b4f194"
 dependencies = [
  "futures",
  "object_store",
  "pyo3",
- "pyo3-async-runtimes",
+ "pyo3-async-runtimes 0.22.0",
  "thiserror",
+ "url",
 ]
 
 [[package]]
@@ -2181,9 +2196,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
 [[package]]
 name = "tokio"
-version = "1.40.0"
+version = "1.41.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998"
+checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
 dependencies = [
  "backtrace",
  "bytes",
diff --git a/arro3-io/Cargo.toml b/arro3-io/Cargo.toml
index 4de3aaf..7926ac5 100644
--- a/arro3-io/Cargo.toml
+++ b/arro3-io/Cargo.toml
@@ -45,5 +45,6 @@ pyo3-async-runtimes = { workspace = true, features = [
     "tokio-runtime",
 ], optional = true }
 pyo3-file = { workspace = true }
-pyo3-object_store = { git = "https://github.com/developmentseed/object-store-rs", rev = "922b58ff784271345ce80342cf4cd6cddce61adf", optional = true }
+pyo3-object_store = { git = "https://github.com/developmentseed/obstore", rev = "f0dad90f1e5e157760335d1ccb4045e1f3b4f194", optional = true }
 thiserror = { workspace = true }
+tokio = "1.41.1"
diff --git a/arro3-io/python/arro3/io/_io.pyi b/arro3-io/python/arro3/io/_io.pyi
index 06a37a7..c331df9 100644
--- a/arro3-io/python/arro3/io/_io.pyi
+++ b/arro3-io/python/arro3/io/_io.pyi
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import IO, Literal, Sequence
+from typing import IO, Literal, Self, Sequence
 
 # Note: importing with
 # `from arro3.core import Array`
@@ -267,6 +267,24 @@ ParquetEncoding = Literal[
 ]
 """Allowed Parquet encodings."""
 
+class ParquetRecordBatchStream:
+    """
+    A stream of [RecordBatch][core.RecordBatch] that can be polled in a sync or
+    async fashion.
+    """
+
+    def __aiter__(self) -> Self:
+        """Return `Self` as an async iterator."""
+
+    def __iter__(self) -> Self:
+        """Return `Self` as an iterator."""
+
+    async def collect_async(self) -> core.Table:
+        """Collect all remaining batches in the stream into a table."""
+
+    async def __anext__(self) -> core.RecordBatch:
+        """Return the next record batch in the stream."""
+
 def read_parquet(file: IO[bytes] | Path | str) -> core.RecordBatchReader:
     """Read a Parquet file to an Arrow RecordBatchReader
 
     Args:
         file: The path to the Parquet file.
 
     Returns:
         The loaded Arrow data.
""" -async def read_parquet_async(path: str, *, store: ObjectStore) -> core.Table: - """Read a Parquet file to an Arrow Table in an async fashion +async def read_parquet_async( + path: str, *, store: ObjectStore +) -> ParquetRecordBatchStream: + """Create an async stream of Arrow record batches from a Parquet file. Args: file: The path to the Parquet file in the given store diff --git a/arro3-io/src/lib.rs b/arro3-io/src/lib.rs index 3705acf..6f0f850 100644 --- a/arro3-io/src/lib.rs +++ b/arro3-io/src/lib.rs @@ -40,6 +40,7 @@ fn _io(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(___version))?; pyo3_object_store::register_store_module(py, m, "arro3.io")?; + pyo3_object_store::register_exceptions_module(py, m, "arro3.io")?; m.add_wrapped(wrap_pyfunction!(csv::infer_csv_schema))?; m.add_wrapped(wrap_pyfunction!(csv::read_csv))?; diff --git a/arro3-io/src/parquet.rs b/arro3-io/src/parquet.rs index 09d7fbd..47c8b54 100644 --- a/arro3-io/src/parquet.rs +++ b/arro3-io/src/parquet.rs @@ -3,22 +3,26 @@ use std::str::FromStr; use std::sync::Arc; use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_schema::SchemaRef; +use futures::StreamExt; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use parquet::arrow::arrow_writer::ArrowWriterOptions; -use parquet::arrow::async_reader::ParquetObjectReader; +use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream}; use parquet::arrow::ArrowWriter; +use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::basic::{Compression, Encoding}; use parquet::file::properties::{WriterProperties, WriterVersion}; use parquet::format::KeyValue; use parquet::schema::types::ColumnPath; -use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3_arrow::error::PyArrowResult; use pyo3_arrow::input::AnyRecordBatch; -use pyo3_arrow::{PyRecordBatchReader, PyTable}; +use pyo3_arrow::{PyRecordBatch, PyRecordBatchReader, PyTable}; use pyo3_object_store::PyObjectStore; +use tokio::sync::Mutex; -use crate::error::Arro3IoResult; +use crate::error::{Arro3IoError, Arro3IoResult}; use crate::utils::{FileReader, FileWriter}; #[pyfunction] @@ -49,21 +53,92 @@ pub fn read_parquet_async( py: Python, path: String, store: PyObjectStore, -) -> PyArrowResult { - let fut = pyo3_async_runtimes::tokio::future_into_py(py, async move { +) -> PyResult> { + pyo3_async_runtimes::tokio::future_into_py(py, async move { Ok(read_parquet_async_inner(store.into_inner(), path).await?) 
-    })?;
+    })
+}
+
+struct PyRecordBatchWrapper(PyRecordBatch);
+
+impl IntoPy<PyObject> for PyRecordBatchWrapper {
+    fn into_py(self, py: Python<'_>) -> PyObject {
+        self.0.to_arro3(py).unwrap()
+    }
+}
+
+struct PyTableWrapper(PyTable);
+
+impl IntoPy<PyObject> for PyTableWrapper {
+    fn into_py(self, py: Python<'_>) -> PyObject {
+        self.0.to_arro3(py).unwrap()
+    }
+}
+
+#[pyclass(name = "ParquetRecordBatchStream")]
+struct PyParquetRecordBatchStream {
+    stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
+    schema: SchemaRef,
+}
+
+#[pymethods]
+impl PyParquetRecordBatchStream {
+    fn __aiter__(slf: Py<Self>) -> Py<Self> {
+        slf
+    }
+
+    fn __anext__<'py>(&'py mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let stream = self.stream.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
+    }
+
+    fn collect_async<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let stream = self.stream.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, collect_stream(stream, self.schema.clone()))
+    }
+}
 
-    Ok(fut.into())
+async fn next_stream(
+    stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
+    sync: bool,
+) -> PyResult<PyRecordBatchWrapper> {
+    let mut stream = stream.lock().await;
+    match stream.next().await {
+        Some(Ok(batch)) => Ok(PyRecordBatchWrapper(PyRecordBatch::new(batch))),
+        Some(Err(err)) => Err(Arro3IoError::ParquetError(err).into()),
+        None => {
+            // Depending on whether the iteration is sync or not, we raise either a
+            // StopIteration or a StopAsyncIteration
+            if sync {
+                Err(PyStopIteration::new_err("stream exhausted"))
+            } else {
+                Err(PyStopAsyncIteration::new_err("stream exhausted"))
+            }
+        }
+    }
+}
+
+async fn collect_stream(
+    stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
+    schema: SchemaRef,
+) -> PyResult<PyTableWrapper> {
+    let mut stream = stream.lock().await;
+    let mut batches: Vec<_> = vec![];
+    loop {
+        match stream.next().await {
+            Some(Ok(batch)) => {
+                batches.push(batch);
+            }
+            Some(Err(err)) => return Err(Arro3IoError::ParquetError(err).into()),
+            None => return Ok(PyTableWrapper(PyTable::try_new(batches, schema)?)),
+        };
+    }
 }
 
 async fn read_parquet_async_inner(
     store: Arc<dyn ObjectStore>,
     path: String,
-) -> Arro3IoResult<PyTable> {
-    use futures::TryStreamExt;
-    use parquet::arrow::ParquetRecordBatchStreamBuilder;
-
+) -> Arro3IoResult<PyParquetRecordBatchStream> {
     let meta = store.head(&path.into()).await?;
 
     let object_reader = ParquetObjectReader::new(store, meta);
@@ -74,8 +149,10 @@ async fn read_parquet_async_inner(
     let arrow_schema = Arc::new(reader.schema().as_ref().clone().with_metadata(metadata));
 
-    let batches = reader.try_collect::<Vec<_>>().await?;
-    Ok(PyTable::try_new(batches, arrow_schema)?)
+    Ok(PyParquetRecordBatchStream {
+        stream: Arc::new(Mutex::new(reader)),
+        schema: arrow_schema,
+    })
 }
 
 pub(crate) struct PyWriterVersion(WriterVersion);
diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py
index 1d41d34..50a8a77 100644
--- a/tests/io/test_parquet.py
+++ b/tests/io/test_parquet.py
@@ -2,7 +2,8 @@
 import pyarrow as pa
 import pyarrow.parquet as pq
 
-from arro3.io import read_parquet, write_parquet
+from arro3.io import read_parquet, read_parquet_async, write_parquet
+from arro3.io.store import HTTPStore
 
 
 def test_parquet_round_trip():
@@ -42,3 +43,28 @@ def test_copy_parquet_kv_metadata():
 
     reader = read_parquet("test.parquet")
     assert reader.schema.metadata[b"hello"] == b"world"
+
+
+async def test_stream_parquet():
+    from time import time
+
+    t0 = time()
+    url = "https://overturemaps-us-west-2.s3.amazonaws.com/release/2024-03-12-alpha.0/theme=buildings/type=building/part-00217-4dfc75cd-2680-4d52-b5e0-f4cc9f36b267-c000.zstd.parquet"
+    store = HTTPStore.from_url(url)
+    stream = await read_parquet_async("", store=store)
+    t1 = time()
+    first = await stream.__anext__()
+    t2 = time()
+
+    print(t1 - t0)
+    print(t2 - t1)
+
+    test = await stream.collect_async()
+    len(test)
+    async for batch in stream:
+        break
+
+    batch.num_rows
+    x = await stream.__anext__()
+
+    pass
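Usage note (not part of the patch): a minimal sketch of how the new streaming API is meant to be consumed, based only on the `_io.pyi` stubs and `test_stream_parquet` above. The Overture Maps URL is the same one the test uses, and the `asyncio.run` wrapper is only there to make the snippet self-contained.

import asyncio

from arro3.io import read_parquet_async
from arro3.io.store import HTTPStore


async def main() -> None:
    # Same file as in test_stream_parquet; any Parquet file reachable over HTTP works.
    url = "https://overturemaps-us-west-2.s3.amazonaws.com/release/2024-03-12-alpha.0/theme=buildings/type=building/part-00217-4dfc75cd-2680-4d52-b5e0-f4cc9f36b267-c000.zstd.parquet"
    store = HTTPStore.from_url(url)

    # read_parquet_async now resolves to a ParquetRecordBatchStream instead of a
    # fully materialized Table.
    stream = await read_parquet_async("", store=store)

    # Batches can be pulled lazily with async iteration...
    async for batch in stream:
        print(batch.num_rows)
        break

    # ...or everything still remaining in the stream can be collected into a Table.
    table = await stream.collect_async()
    print(len(table))


asyncio.run(main())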