From c2ef569ade9570d8827e7f4b46a8a6ec718c013c Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 10 Dec 2024 17:39:12 +1100 Subject: [PATCH] test(python): Add test for BytesIO overwritten after scan (#20240) --- crates/polars-python/src/dataframe/io.rs | 18 +++++-------- crates/polars-python/src/dataframe/serde.rs | 4 +-- crates/polars-python/src/file.rs | 28 +++++++++++++-------- py-polars/tests/unit/io/test_scan.py | 14 +++++++++++ 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index 37ba89418a28..960816231cd3 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -20,7 +20,7 @@ use crate::conversion::Wrap; use crate::error::PyPolarsErr; use crate::file::{ get_either_file, get_file_like, get_mmap_bytes_reader, get_mmap_bytes_reader_and_path, - read_if_bytesio, EitherRustPythonFile, + EitherRustPythonFile, }; use crate::prelude::{parse_cloud_options, PyCompatLevel}; @@ -37,7 +37,7 @@ impl PyDataFrame { )] pub fn read_csv( py: Python, - mut py_f: Bound, + py_f: Bound, infer_schema_length: Option, chunk_size: usize, has_header: bool, @@ -92,7 +92,6 @@ impl PyDataFrame { .collect::>() }); - py_f = read_if_bytesio(py_f); let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; let df = py.allow_threads(move || { CsvReadOptions::default() @@ -193,14 +192,12 @@ impl PyDataFrame { #[pyo3(signature = (py_f, infer_schema_length, schema, schema_overrides))] pub fn read_json( py: Python, - mut py_f: Bound, + py_f: Bound, infer_schema_length: Option, schema: Option>, schema_overrides: Option>, ) -> PyResult { assert!(infer_schema_length != Some(0)); - use crate::file::read_if_bytesio; - py_f = read_if_bytesio(py_f); let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; py.allow_threads(move || { @@ -226,12 +223,11 @@ impl PyDataFrame { #[pyo3(signature = (py_f, ignore_errors, schema, schema_overrides))] pub fn read_ndjson( py: Python, - mut py_f: Bound, + py_f: Bound, ignore_errors: bool, schema: Option>, schema_overrides: Option>, ) -> PyResult { - py_f = read_if_bytesio(py_f); let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; let mut builder = JsonReader::new(mmap_bytes_r) @@ -257,7 +253,7 @@ impl PyDataFrame { #[pyo3(signature = (py_f, columns, projection, n_rows, row_index, memory_map))] pub fn read_ipc( py: Python, - mut py_f: Bound, + py_f: Bound, columns: Option>, projection: Option>, n_rows: Option, @@ -268,7 +264,6 @@ impl PyDataFrame { name: name.into(), offset, }); - py_f = read_if_bytesio(py_f); let (mmap_bytes_r, mmap_path) = get_mmap_bytes_reader_and_path(&py_f)?; let mmap_path = if memory_map { mmap_path } else { None }; @@ -290,7 +285,7 @@ impl PyDataFrame { #[pyo3(signature = (py_f, columns, projection, n_rows, row_index, rechunk))] pub fn read_ipc_stream( py: Python, - mut py_f: Bound, + py_f: Bound, columns: Option>, projection: Option>, n_rows: Option, @@ -301,7 +296,6 @@ impl PyDataFrame { name: name.into(), offset, }); - py_f = read_if_bytesio(py_f); let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; let df = py.allow_threads(move || { IpcStreamReader::new(mmap_bytes_r) diff --git a/crates/polars-python/src/dataframe/serde.rs b/crates/polars-python/src/dataframe/serde.rs index b08d2bd5ed85..ac534cfee35d 100644 --- a/crates/polars-python/src/dataframe/serde.rs +++ b/crates/polars-python/src/dataframe/serde.rs @@ -74,9 +74,7 @@ impl PyDataFrame { /// Deserialize a file-like object containing JSON string data into a DataFrame. #[staticmethod] #[cfg(feature = "json")] - pub fn deserialize_json(py: Python, mut py_f: Bound) -> PyResult { - use crate::file::read_if_bytesio; - py_f = read_if_bytesio(py_f); + pub fn deserialize_json(py: Python, py_f: Bound) -> PyResult { let mut mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; py.allow_threads(move || { diff --git a/crates/polars-python/src/file.rs b/crates/polars-python/src/file.rs index ffec35a34f52..996a63ea1e48 100644 --- a/crates/polars-python/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -372,12 +372,12 @@ pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult> } /// If the give file-like is a BytesIO, read its contents. -pub fn read_if_bytesio(py_f: Bound) -> Bound { +fn read_if_bytesio(py_f: Bound) -> Bound { if py_f.getattr("read").is_ok() { let Ok(bytes) = py_f.call_method0("getvalue") else { return py_f; }; - if bytes.downcast::().is_ok() { + if bytes.downcast::().is_ok() || bytes.downcast::().is_ok() { return bytes.clone(); } } @@ -386,24 +386,32 @@ pub fn read_if_bytesio(py_f: Bound) -> Bound { /// Create reader from PyBytes or a file-like object. To get BytesIO to have /// better performance, use read_if_bytesio() before calling this. -pub fn get_mmap_bytes_reader<'a>( - py_f: &'a Bound<'a, PyAny>, -) -> PyResult> { +pub fn get_mmap_bytes_reader(py_f: &Bound) -> PyResult> { get_mmap_bytes_reader_and_path(py_f).map(|t| t.0) } -pub fn get_mmap_bytes_reader_and_path<'a>( - py_f: &'a Bound<'a, PyAny>, -) -> PyResult<(Box, Option)> { +pub fn get_mmap_bytes_reader_and_path( + py_f: &Bound, +) -> PyResult<(Box, Option)> { + let py_f = read_if_bytesio(py_f.clone()); + // bytes object if let Ok(bytes) = py_f.downcast::() { - Ok((Box::new(Cursor::new(bytes.as_bytes())), None)) + Ok(( + Box::new(Cursor::new(MemSlice::from_arc( + bytes.as_bytes(), + Arc::new(py_f.to_object(py_f.py())), + ))), + None, + )) } // string so read file else { match get_either_buffer_or_path(py_f.to_object(py_f.py()), false)? { (EitherRustPythonFile::Rust(f), path) => Ok((Box::new(f), path)), - (EitherRustPythonFile::Py(f), path) => Ok((Box::new(f), path)), + (EitherRustPythonFile::Py(f), path) => { + Ok((Box::new(Cursor::new(f.to_memslice())), path)) + }, } } } diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index b91f2d68c05e..569c27260513 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -759,6 +759,20 @@ def test_scan_in_memory(method: str) -> None: assert_frame_equal(df.vstack(df).slice(-1, 1), result) +def test_scan_pyobject_zero_copy_buffer_mutate() -> None: + f = io.BytesIO() + + df = pl.DataFrame({"x": [1, 2, 3, 4, 5]}) + df.write_ipc(f) + f.seek(0) + + q = pl.scan_ipc(f) + assert_frame_equal(q.collect(), df) + + f.write(b"AAA") + assert_frame_equal(q.collect(), df) + + @pytest.mark.parametrize( "method", ["csv", "ndjson"],