From 69dd9d9105fccc29953bfb326932e3436997427b Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Thu, 16 Nov 2023 19:25:06 +0100
Subject: [PATCH] add tests to check rust pyo3 writer

---
 crates/deltalake-core/src/operations/mod.rs |   4 +-
 python/deltalake/_internal.pyi              |   5 +-
 python/deltalake/writer.py                  |  75 +++++---
 python/src/lib.rs                           |  64 ++++---
 python/tests/test_writer.py                 | 179 ++++++++++++++++----
 5 files changed, 247 insertions(+), 80 deletions(-)

diff --git a/crates/deltalake-core/src/operations/mod.rs b/crates/deltalake-core/src/operations/mod.rs
index 4b50b32d4b..2fdb2fd3a9 100644
--- a/crates/deltalake-core/src/operations/mod.rs
+++ b/crates/deltalake-core/src/operations/mod.rs
@@ -82,7 +82,9 @@ impl DeltaOps {
     /// try from uri with storage options
     pub async fn try_from_uri_with_storage_options(uri: impl AsRef<str>, storage_options: HashMap<String, String>) -> DeltaResult<Self> {
-        let mut table = DeltaTableBuilder::from_uri(uri).with_storage_options(storage_options).build()?;
+        let mut table = DeltaTableBuilder::from_uri(uri)
+            .with_storage_options(storage_options)
+            .build()?;
         // We allow for uninitialized locations, since we may want to create the table
         match table.load().await {
             Ok(_) => Ok(table.into()),
diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index 0b7875e220..7ff50ecd2b 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -142,10 +142,13 @@ def write_new_deltalake(
 def write_to_deltalake(
     table_uri: str,
     data: pyarrow.RecordBatchReader,
-    partition_by: List[str],
+    partition_by: Optional[List[str]],
     mode: str,
     max_rows_per_group: int,
     overwrite_schema: bool,
+    name: Optional[str],
+    description: Optional[str],
+    configuration: Optional[Mapping[str, Optional[str]]],
     storage_options: Optional[Dict[str, str]],
 ) -> None: ...
 def batch_distinct(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: ...
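
For context, a minimal sketch of how the new optional stub parameters are meant to surface through the public write_deltalake wrapper (the path is hypothetical; this mirrors the metadata roundtrip test further down):

    import pyarrow as pa

    from deltalake import DeltaTable, write_deltalake

    data = pa.table({"x": pa.array([1, 2, 3])})

    # name, description and configuration are forwarded by write_deltalake
    # into the extended write_to_deltalake stub when engine="rust".
    write_deltalake(
        "/tmp/example_table",  # hypothetical path
        data,
        mode="overwrite",
        name="example",
        description="example table",
        configuration={"delta.appendOnly": "false"},
        engine="rust",
    )

    # The metadata should then be readable back from the commit log.
    metadata = DeltaTable("/tmp/example_table").metadata()
    assert metadata.name == "example"
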
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index b58287e10d..47651dd16b 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -167,17 +167,66 @@ def write_deltalake(
         storage_options = table._storage_options or {}
         storage_options.update(storage_options or {})
 
-    if engine == "pyarrow":
+    __enforce_append_only(table=table, configuration=configuration, mode=mode)
+
+    if isinstance(partition_by, str):
+        partition_by = [partition_by]
+
+    if table:
+        table.update_incremental()
+
+    if engine == "rust":
+        # Easier to do this check in Python than in Rust
+        if table is not None and mode == "ignore":
+            return
+        ### COMMENTS ###
+        # - Don't check partition columns if they are the same; this is done implicitly on the Rust side
+        # - Consolidate the RecordBatchReader part with the new update
+        # - Add overwrite schema functionality in the Rust writer
+        # - Figure out how to add name, description and configuration to the correct metadata in the transaction
+
+        if isinstance(data, RecordBatchReader):
+            batch_iter = data
+        elif isinstance(data, pa.RecordBatch):
+            batch_iter = [data]
+        elif isinstance(data, pa.Table):
+            batch_iter = data.to_reader()
+        elif isinstance(data, ds.Dataset):
+            batch_iter = data.scanner().to_reader()
+        elif isinstance(data, pd.DataFrame):
+            batch_iter = pa.Table.from_pandas(data).to_reader()
+        else:
+            batch_iter = data
+
+        if schema is None:
+            if isinstance(batch_iter, RecordBatchReader):
+                schema = batch_iter.schema
+            elif isinstance(batch_iter, Iterable):
+                raise ValueError("You must provide schema if data is Iterable")
+
+        data = RecordBatchReader.from_batches(schema, (batch for batch in batch_iter))
+        _write_to_deltalake(
+            table_uri=table_uri,
+            data=data,
+            partition_by=partition_by,
+            mode=mode,
+            max_rows_per_group=max_rows_per_group,
+            overwrite_schema=overwrite_schema,
+            name=name,
+            description=description,
+            configuration=configuration,
+            storage_options=storage_options,
+        )
+        if table:
+            table.update_incremental()
+
+    elif engine == "pyarrow":
         if _has_pandas and isinstance(data, pd.DataFrame):
             if schema is not None:
                 data = pa.Table.from_pandas(data, schema=schema)
             else:
                 data, schema = delta_arrow_schema_from_pandas(data)
 
-        # We need to write against the latest table version
-        if table:
-            table.update_incremental()
-
         if schema is None:
             if isinstance(data, RecordBatchReader):
                 schema = data.schema
@@ -193,11 +242,6 @@ def write_deltalake(
 
         filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))
 
-        __enforce_append_only(table=table, configuration=configuration, mode=mode)
-
-        if isinstance(partition_by, str):
-            partition_by = [partition_by]
-
         if table:  # already exists
             if schema != table.schema().to_pyarrow(
                 as_large_types=large_dtypes
             ):
                 raise ValueError(
                     "Schema of data does not match table schema\n"
                     f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}"
                 )
-
             if mode == "error":
                 raise AssertionError("DeltaTable already exists.")
             elif mode == "ignore":
@@ -396,15 +439,7 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
                 )
         table.update_incremental()
     else:
-        _write_to_deltalake(
-            table_uri=table_uri,
-            data=data.to_reader(),
-            partition_by=partition_by,
-            mode=mode,
-            max_rows_per_group=max_rows_per_group,
-            overwrite_schema=overwrite_schema,
-            storage_options=storage_options,
-        )
+        raise ValueError("Only `pyarrow` or `rust` are valid inputs for the engine.")
 
 
 def __enforce_append_only(
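
A short sketch of the dispatch above from the caller's side: with engine="rust", a plain iterable of record batches still needs an explicit schema before it can be wrapped into a RecordBatchReader (paths and data are hypothetical):

    import pyarrow as pa

    from deltalake import write_deltalake

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])

    # An iterator exposes no schema attribute, so the rust branch raises ValueError.
    try:
        write_deltalake("/tmp/tbl", iter([batch]), mode="overwrite", engine="rust")
    except ValueError:
        pass

    # With an explicit schema the batches can be wrapped and handed to the Rust writer.
    write_deltalake(
        "/tmp/tbl",
        iter([batch]),
        schema=batch.schema,
        mode="overwrite",
        engine="rust",
    )
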
diff --git a/python/src/lib.rs b/python/src/lib.rs
index dff380decf..9fc6a0742c 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -43,6 +43,7 @@ use deltalake::DeltaTableBuilder;
 use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::{PyFrozenSet, PyType};
+use serde_json::Value;
 
 use crate::error::DeltaProtocolError;
 use crate::error::PythonError;
@@ -1132,41 +1133,58 @@ impl From<&PyAddAction> for Add {
 fn write_to_deltalake(
     table_uri: String,
     data: PyArrowType<ArrowArrayStreamReader>,
-    // schema: Option<PyArrowType<ArrowSchema>>,
-    partition_by: Vec<String>, // or Vec
+    // schema: Option<PyArrowType<ArrowSchema>>, // maybe do the schema casting on the Python side
     mode: String,
-    // max_partitions: i64,
-    // max_rows_per_file: i64,
-    // min_rows_per_group: i64,
     max_rows_per_group: i64,
-    // name: Option<String>,
-    // description: Option<String>,
-    // configuration: Option<HashMap<String, Option<String>>>,
     overwrite_schema: bool,
+    partition_by: Option<Vec<String>>,
+    name: Option<String>,
+    description: Option<String>,
+    configuration: Option<HashMap<String, Option<String>>>,
     storage_options: Option<HashMap<String, String>>,
 ) -> PyResult<()> {
-
-    // let schema = data.0.schema();
     let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<RecordBatch>>();
-    // let batches = data.0;
+    let save_mode = save_mode_from_str(&mode)?;
 
-    let mode = save_mode_from_str(&mode)?;
-    // let new_schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?;
+    let mut metadata: HashMap<String, Value> = HashMap::new();
 
-    // let existing_schema = self._table.get_schema().map_err(PythonError::from)?;
+    if let Some(name) = name {
+        metadata.insert("name".to_string(), name.into());
+    }
+
+    if let Some(description) = description {
+        metadata.insert("description".to_string(), description.into());
+    }
 
-    // let schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?;
+    if let Some(configuration) = configuration {
+        metadata.insert("configuration".to_string(), json!(configuration));
+    }
 
+    // This should be done when the table cannot be loaded ...
+    // match save_mode {
+    //     SaveMode::Ignore => {
+    //         return Ok(())
+    //     }
+    //     _ => ()
+    // }
     let options = storage_options.clone().unwrap_or_default();
-    let table = rt()?.block_on(DeltaOps::try_from_uri_with_storage_options(&table_uri, options)).map_err(PythonError::from)?;
+    let table = rt()?
+        .block_on(DeltaOps::try_from_uri_with_storage_options(
+            &table_uri, options,
+        ))
+        .map_err(PythonError::from)?;
 
-    let builder = table
+    let mut builder = table
         .write(batches)
-        .with_save_mode(mode)
+        .with_save_mode(save_mode)
         .with_overwrite_schema(overwrite_schema)
-        .with_write_batch_size(max_rows_per_group as usize)
-        .with_partition_columns(partition_by);
-
+        .with_metadata(metadata)
+        .with_write_batch_size(max_rows_per_group as usize);
+
+    if let Some(partition_columns) = partition_by {
+        builder = builder.with_partition_columns(partition_columns);
+    }
+
     rt()?
         .block_on(builder.into_future())
         .map_err(PythonError::from)?;
@@ -1174,9 +1192,7 @@ fn write_to_deltalake(
     Ok(())
 }
 
-
-
-
+use serde_json::json;
 #[pyfunction]
 #[allow(clippy::too_many_arguments)]
 fn write_new_deltalake(
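
To make the save-mode mapping above concrete, a hedged sketch of the behaviour the Rust path is expected to exhibit from Python (this is what test_write_modes_rust below asserts; the path is hypothetical):

    import pyarrow as pa

    from deltalake import DeltaTable, write_deltalake
    from deltalake.exceptions import DeltaError

    data = pa.table({"x": pa.array([1, 2, 3])})
    write_deltalake("/tmp/modes_table", data, engine="rust")

    # "error" refuses to write to an existing table.
    try:
        write_deltalake("/tmp/modes_table", data, mode="error", engine="rust")
    except DeltaError:
        pass

    # "ignore" is a no-op on an existing table (short-circuited in writer.py).
    write_deltalake("/tmp/modes_table", data, mode="ignore", engine="rust")

    # "append" adds the new batches; "overwrite" replaces the table contents.
    write_deltalake("/tmp/modes_table", data, mode="append", engine="rust")
    assert DeltaTable("/tmp/modes_table").to_pyarrow_table().num_rows == 6
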
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index d048f8b79b..d2a2abefa6 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -6,7 +6,7 @@
 import threading
 from datetime import date, datetime
 from math import inf
-from typing import Any, Dict, Iterable, List
+from typing import Any, Dict, Iterable, List, Literal
 from unittest.mock import Mock
 
 import pyarrow as pa
@@ -17,7 +17,7 @@
 from pyarrow.lib import RecordBatchReader
 
 from deltalake import DeltaTable, write_deltalake
-from deltalake.exceptions import CommitFailedError, DeltaProtocolError
+from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError
 from deltalake.table import ProtocolVersions
 from deltalake.writer import try_get_table_and_table_uri
@@ -29,24 +29,30 @@
     _has_pandas = True
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 @pytest.mark.skip(reason="Waiting on #570")
-def test_handle_existing(tmp_path: pathlib.Path, sample_data: pa.Table):
+def test_handle_existing(
+    tmp_path: pathlib.Path, sample_data: pa.Table, engine: Literal["pyarrow", "rust"]
+):
     # if uri points to a non-empty directory that isn't a delta table, error
     tmp_path
     p = tmp_path / "hello.txt"
     p.write_text("hello")
 
     with pytest.raises(OSError) as exception:
-        write_deltalake(tmp_path, sample_data, mode="overwrite")
+        write_deltalake(tmp_path, sample_data, mode="overwrite", engine=engine)
 
     assert "directory is not empty" in str(exception)
 
 
-def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_roundtrip_basic(
+    tmp_path: pathlib.Path, sample_data: pa.Table, engine: Literal["pyarrow", "rust"]
+):
     # Check we can create the subdirectory
     tmp_path = tmp_path / "path" / "to" / "table"
     start_time = datetime.now().timestamp()
-    write_deltalake(tmp_path, sample_data)
+    write_deltalake(tmp_path, sample_data, engine=engine)
     end_time = datetime.now().timestamp()
 
     assert ("0" * 20 + ".json") in os.listdir(tmp_path / "_delta_log")
@@ -71,7 +77,8 @@ def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table):
     assert modification_time < end_time
 
 
-def test_roundtrip_nulls(tmp_path: pathlib.Path):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_roundtrip_nulls(tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"]):
     data = pa.table({"x": pa.array([None, None, 1, 2], type=pa.int64())})
     # One row group will have values, one will be all nulls.
     # The first will have None in min and max stats, so we need to handle that.
@@ -91,6 +98,7 @@ def test_roundtrip_nulls(tmp_path: pathlib.Path):
         min_rows_per_group=2,
         max_rows_per_group=2,
         mode="overwrite",
+        engine=engine,
     )
 
     delta_table = DeltaTable(tmp_path)
@@ -105,11 +113,23 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str):
     bad_data = pa.table({"x": pa.array([1, 2, 3])})
 
     with pytest.raises(ValueError):
-        write_deltalake(existing_table, bad_data, mode=mode)
+        write_deltalake(existing_table, bad_data, mode=mode, engine="pyarrow")
 
     table_uri = existing_table._table.table_uri()
     with pytest.raises(ValueError):
-        write_deltalake(table_uri, bad_data, mode=mode)
+        write_deltalake(table_uri, bad_data, mode=mode, engine="pyarrow")
+
+
+@pytest.mark.parametrize("mode", ["append", "overwrite"])
+def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str):
+    bad_data = pa.table({"x": pa.array([1, 2, 3])})
+
+    with pytest.raises(DeltaError):
+        write_deltalake(existing_table, bad_data, mode=mode, engine="rust")
+
+    table_uri = existing_table._table.table_uri()
+    with pytest.raises(DeltaError):
+        write_deltalake(table_uri, bad_data, mode=mode, engine="rust")
 
 
 def test_update_schema(existing_table: DeltaTable):
@@ -125,12 +145,39 @@ def test_update_schema(existing_table: DeltaTable):
     assert existing_table.schema().to_pyarrow() == new_data.schema
 
 
-def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch):
+# def test_update_schema_rust_writer(existing_table: DeltaTable):  # Test fails
+#     new_data = pa.table({"x": pa.array([1, 2, 3])})
+
+#     with pytest.raises(DeltaError):
+#         write_deltalake(
+#             existing_table,
+#             new_data,
+#             mode="append",
+#             overwrite_schema=True,
+#             engine="rust",
+#         )
+
+#     write_deltalake(
+#         existing_table, new_data, mode="overwrite", overwrite_schema=True, engine="rust"
+#     )
+
+#     read_data = existing_table.to_pyarrow_table()
+#     assert new_data == read_data
+#     assert existing_table.schema().to_pyarrow() == new_data.schema
+
+
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_local_path(
+    tmp_path: pathlib.Path,
+    sample_data: pa.Table,
+    monkeypatch,
+    engine: Literal["pyarrow", "rust"],
+):
     monkeypatch.chdir(tmp_path)  # Make tmp_path the working directory
     (tmp_path / "path/to/table").mkdir(parents=True)
 
     local_path = "./path/to/table"
-    write_deltalake(local_path, sample_data)
+    write_deltalake(local_path, sample_data, engine=engine)
     delta_table = DeltaTable(local_path)
     assert delta_table.schema().to_pyarrow() == sample_data.schema
 
@@ -138,13 +185,15 @@ def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch):
     assert table == sample_data
 
 
-def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table, engine):
     write_deltalake(
         tmp_path,
         sample_data,
         name="test_name",
        description="test_desc",
         configuration={"delta.appendOnly": "false", "foo": "bar"},
+        engine=engine,
     )
 
     delta_table = DeltaTable(tmp_path)
@@ -156,6 +205,7 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
     assert metadata.configuration == {"delta.appendOnly": "false", "foo": "bar"}
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 @pytest.mark.parametrize(
     "column",
     [
@@ -173,9 +223,9 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table):
     ],
 )
 def test_roundtrip_partitioned(
-    tmp_path: pathlib.Path, sample_data: pa.Table, column: str
+    tmp_path: pathlib.Path, sample_data: pa.Table, column: str, engine
 ):
-    write_deltalake(tmp_path, sample_data, partition_by=column)
+    write_deltalake(tmp_path, sample_data, partition_by=column, engine=engine)
 
     delta_table = DeltaTable(tmp_path)
     assert delta_table.schema().to_pyarrow() == sample_data.schema
@@ -189,11 +239,16 @@ def test_roundtrip_partitioned(
         assert add_path.count("/") == 1
 
 
-def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_roundtrip_null_partition(
+    tmp_path: pathlib.Path, sample_data: pa.Table, engine
+):
     sample_data = sample_data.add_column(
         0, "utf8_with_nulls", pa.array(["a"] * 4 + [None])
     )
-    write_deltalake(tmp_path, sample_data, partition_by=["utf8_with_nulls"])
+    write_deltalake(
+        tmp_path, sample_data, partition_by=["utf8_with_nulls"], engine=engine
+    )
 
     delta_table = DeltaTable(tmp_path)
     assert delta_table.schema().to_pyarrow() == sample_data.schema
@@ -203,8 +258,13 @@ def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table)
     assert table == sample_data
 
 
-def test_roundtrip_multi_partitioned(tmp_path: pathlib.Path, sample_data: pa.Table):
-    write_deltalake(tmp_path, sample_data, partition_by=["int32", "bool"])
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_roundtrip_multi_partitioned(
+    tmp_path: pathlib.Path, sample_data: pa.Table, engine
+):
+    write_deltalake(
+        tmp_path, sample_data, partition_by=["int32", "bool"], engine=engine
+    )
 
     delta_table = DeltaTable(tmp_path)
     assert delta_table.schema().to_pyarrow() == sample_data.schema
@@ -236,7 +296,25 @@ def test_write_modes(tmp_path: pathlib.Path, sample_data: pa.Table):
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
-def test_append_only_should_append_only_with_the_overwrite_mode(
+def test_write_modes_rust(tmp_path: pathlib.Path, sample_data: pa.Table):
+    write_deltalake(tmp_path, sample_data)
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+    with pytest.raises(DeltaError):
+        write_deltalake(tmp_path, sample_data, mode="error", engine="rust")
+
+    write_deltalake(tmp_path, sample_data, mode="ignore", engine="rust")
+    assert ("0" * 19 + "1.json") not in os.listdir(tmp_path / "_delta_log")
+
+    write_deltalake(tmp_path, sample_data, mode="append", engine="rust")
+    expected = pa.concat_tables([sample_data, sample_data])
+    assert DeltaTable(tmp_path).to_pyarrow_table() == expected
+
+    write_deltalake(tmp_path, sample_data, mode="overwrite", engine="rust")
+    assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
+
+
+def test_append_only_should_append_only_with_the_overwrite_mode(  # Create rust equivalent
     tmp_path: pathlib.Path, sample_data: pa.Table
 ):
     config = {"delta.appendOnly": "true"}
@@ -265,8 +343,9 @@ def test_append_only_should_append_only_with_the_overwrite_mode(
     assert table.version() == 1
 
 
-def test_writer_with_table(existing_table: DeltaTable, sample_data: pa.Table):
-    write_deltalake(existing_table, sample_data, mode="overwrite")
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_writer_with_table(existing_table: DeltaTable, sample_data: pa.Table, engine):
+    write_deltalake(existing_table, sample_data, mode="overwrite", engine=engine)
     assert existing_table.to_pyarrow_table() == sample_data
 
 
@@ -277,9 +356,25 @@ def test_fails_wrong_partitioning(existing_table: DeltaTable, sample_data: pa.Ta
         )
 
 
+def test_fails_wrong_partitioning_rust_writer(
+    existing_table: DeltaTable, sample_data: pa.Table
+):
+    with pytest.raises(DeltaError):
+        write_deltalake(
+            existing_table,
+            sample_data,
+            mode="append",
+            partition_by="int32",
+            engine="rust",
+        )
+
+
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 @pytest.mark.pandas
 @pytest.mark.parametrize("schema_provided", [True, False])
-def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided):
+def test_write_pandas(
+    tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided, engine
+):
     # When timestamp is converted to Pandas, it gets cast to ns resolution,
     # but Delta Lake schemas only support us resolution.
     sample_pandas = sample_data.to_pandas()
@@ -287,42 +382,52 @@ def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_prov
         schema = sample_data.schema
     else:
         schema = None
-    write_deltalake(tmp_path, sample_pandas, schema=schema)
+    write_deltalake(tmp_path, sample_pandas, schema=schema, engine=engine)
 
     delta_table = DeltaTable(tmp_path)
     df = delta_table.to_pandas()
     assert_frame_equal(df, sample_pandas)
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 def test_write_iterator(
-    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, engine
 ):
     batches = existing_table.to_pyarrow_dataset().to_batches()
     with pytest.raises(ValueError):
-        write_deltalake(tmp_path, batches, mode="overwrite")
+        write_deltalake(tmp_path, batches, mode="overwrite", engine=engine)
 
-    write_deltalake(tmp_path, batches, schema=sample_data.schema, mode="overwrite")
+    write_deltalake(
+        tmp_path, batches, schema=sample_data.schema, mode="overwrite", engine=engine
+    )
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 def test_write_recordbatchreader(
-    tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
+    tmp_path: pathlib.Path,
+    existing_table: DeltaTable,
+    sample_data: pa.Table,
+    engine: Literal["pyarrow", "rust"],
 ):
     batches = existing_table.to_pyarrow_dataset().to_batches()
     reader = RecordBatchReader.from_batches(
         existing_table.to_pyarrow_dataset().schema, batches
     )
 
-    write_deltalake(tmp_path, reader, mode="overwrite")
+    write_deltalake(tmp_path, reader, mode="overwrite", engine=engine)
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
 
 
-def test_writer_partitioning(tmp_path: pathlib.Path):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_writer_partitioning(
+    tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"]
+):
     test_strings = ["a=b", "hello world", "hello%20world"]
     data = pa.table(
         {"p": pa.array(test_strings), "x": pa.array(range(len(test_strings)))}
     )
 
-    write_deltalake(tmp_path, data)
+    write_deltalake(tmp_path, data, engine=engine)
 
     assert DeltaTable(tmp_path).to_pyarrow_table() == data
@@ -411,7 +516,8 @@ def test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table):
     assert stats["maxValues"] == expected_maxs
 
 
-def test_writer_null_stats(tmp_path: pathlib.Path):
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+def test_writer_null_stats(tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"]):
     data = pa.table(
         {
             "int32": pa.array([1, None, 2, None], pa.int32()),
@@ -419,7 +525,7 @@ def test_writer_null_stats(tmp_path: pathlib.Path):
             "float64": pa.array([1.0, None, None, None], pa.float64()),
             "str": pa.array([None] * 4, pa.string()),
         }
     )
-    write_deltalake(tmp_path, data)
+    write_deltalake(tmp_path, data, engine=engine)
 
     table = DeltaTable(tmp_path)
     stats = get_stats(table)
@@ -428,10 +534,15 @@ def test_writer_null_stats(tmp_path: pathlib.Path):
     assert stats["nullCount"] == expected_nulls
 
 
-def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table):
+@pytest.mark.parametrize("engine", ["pyarrow"])  # This one is broken
+def test_writer_fails_on_protocol(
+    existing_table: DeltaTable,
+    sample_data: pa.Table,
+    engine: Literal["pyarrow", "rust"],
+):
     existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3))
     with pytest.raises(DeltaProtocolError):
-        write_deltalake(existing_table, sample_data, mode="overwrite")
+        write_deltalake(existing_table, sample_data, mode="overwrite", engine=engine)
 
 
 @pytest.mark.parametrize(
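
One behavioural difference worth keeping in mind when reading the tests above: on a schema mismatch the pyarrow path raises ValueError from the Python-side check, while the Rust writer surfaces a DeltaError from the commit attempt. A hedged sketch, assuming an existing table fixture:

    import pyarrow as pa
    import pytest

    from deltalake import write_deltalake
    from deltalake.exceptions import DeltaError

    bad_data = pa.table({"x": pa.array([1, 2, 3])})

    def check_schema_enforcement(existing_table):
        # Hypothetical helper mirroring test_enforce_schema and
        # test_enforce_schema_rust_writer.
        with pytest.raises(ValueError):
            write_deltalake(existing_table, bad_data, mode="append", engine="pyarrow")
        with pytest.raises(DeltaError):
            write_deltalake(existing_table, bad_data, mode="append", engine="rust")
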