From 2d67292f6e2d925f6b910d980a0092315ea2ee5c Mon Sep 17 00:00:00 2001 From: Nikolay Ulmasov Date: Sat, 18 Nov 2023 14:56:17 +0000 Subject: [PATCH] extend write_deltalake to accept both PyArrow and Deltalake schema Signed-off-by: Nikolay Ulmasov --- python/deltalake/writer.py | 5 ++++- python/tests/test_writer.py | 10 +++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index ef4ae3a57b..da94837e22 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -20,6 +20,7 @@ ) from urllib.parse import unquote +from deltalake import Schema from deltalake.fs import DeltaStorageHandler from ._util import encode_partition_value @@ -73,7 +74,7 @@ def write_deltalake( RecordBatchReader, ], *, - schema: Optional[pa.Schema] = None, + schema: Optional[Union[pa.Schema, Schema]] = None, partition_by: Optional[Union[List[str], str]] = None, filesystem: Optional[pa_fs.FileSystem] = None, mode: Literal["error", "append", "overwrite", "ignore"] = "error", @@ -179,6 +180,8 @@ def write_deltalake( raise ValueError("You must provide schema if data is Iterable") else: schema = data.schema + elif isinstance(schema, Schema): + schema = schema.to_pyarrow() if filesystem is not None: raise NotImplementedError("Filesystem support is not yet implemented. #570") diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index d048f8b79b..2a4df8e1f8 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -16,7 +16,7 @@ from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions from pyarrow.lib import RecordBatchReader -from deltalake import DeltaTable, write_deltalake +from deltalake import DeltaTable, Schema, write_deltalake from deltalake.exceptions import CommitFailedError, DeltaProtocolError from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -950,3 +950,11 @@ def test_float_values(tmp_path: pathlib.Path): assert actions["min"].field("x2")[0].as_py() is None assert actions["max"].field("x2")[0].as_py() == 1.0 assert actions["null_count"].field("x2")[0].as_py() == 1 + + +def test_with_deltalake_schema(tmp_path: pathlib.Path, sample_data: pa.Table): + write_deltalake( + tmp_path, sample_data, schema=Schema.from_pyarrow(sample_data.schema) + ) + delta_table = DeltaTable(tmp_path) + assert delta_table.schema().to_pyarrow() == sample_data.schema