From 8ff124db80373b0e22a0e6bbba9ae866ccfebeae Mon Sep 17 00:00:00 2001
From: Nikolay Ulmasov
Date: Sat, 18 Nov 2023 14:56:17 +0000
Subject: [PATCH 1/3] extend write_deltalake to accept both PyArrow and
 Deltalake schema

Signed-off-by: Nikolay Ulmasov
---
 python/deltalake/writer.py  |  3 ++-
 python/tests/test_writer.py | 12 ++++++++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 626fb1a5d9..6a20da69ca 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -21,6 +21,7 @@
 )
 from urllib.parse import unquote

+from deltalake import Schema
 from deltalake.fs import DeltaStorageHandler

 from ._util import encode_partition_value
@@ -142,7 +143,7 @@ def write_deltalake(
         RecordBatchReader,
     ],
     *,
-    schema: Optional[pa.Schema] = None,
+    schema: Optional[Union[pa.Schema, Schema]] = None,
     partition_by: Optional[Union[List[str], str]] = None,
     filesystem: Optional[pa_fs.FileSystem] = None,
     mode: Literal["error", "append", "overwrite", "ignore"] = "error",
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 0a63b16c70..b0177149e8 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -16,8 +16,8 @@
 from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions
 from pyarrow.lib import RecordBatchReader

-from deltalake import DeltaTable, write_deltalake
-from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError
+from deltalake import DeltaTable, Schema, write_deltalake
+from deltalake.exceptions import CommitFailedError, DeltaProtocolError
 from deltalake.table import ProtocolVersions
 from deltalake.writer import try_get_table_and_table_uri

@@ -1176,3 +1176,11 @@ def test_float_values(tmp_path: pathlib.Path):
     assert actions["min"].field("x2")[0].as_py() is None
     assert actions["max"].field("x2")[0].as_py() == 1.0
     assert actions["null_count"].field("x2")[0].as_py() == 1
+
+
+def test_with_deltalake_schema(tmp_path: pathlib.Path, sample_data: pa.Table):
+    write_deltalake(
+        tmp_path, sample_data, schema=Schema.from_pyarrow(sample_data.schema)
+    )
+    delta_table = DeltaTable(tmp_path)
+    assert delta_table.schema().to_pyarrow() == sample_data.schema

From c0635c58f5055252c6805b260e7010ff8b83b319 Mon Sep 17 00:00:00 2001
From: Nikolay Ulmasov
Date: Sat, 18 Nov 2023 15:52:01 +0000
Subject: [PATCH 2/3] fix type check error

Signed-off-by: Nikolay Ulmasov
---
 python/deltalake/writer.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 6a20da69ca..70b3936028 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -336,21 +336,24 @@ def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
             except KeyError:
                 return dtype

-        if partition_by:
-            if PYARROW_MAJOR_VERSION < 12:
-                partition_schema = pa.schema(
-                    [
-                        pa.field(name, _large_to_normal_dtype(schema.field(name).type))
-                        for name in partition_by
-                    ]
-                )
-            else:
-                partition_schema = pa.schema(
-                    [schema.field(name) for name in partition_by]
-                )
-            partitioning = ds.partitioning(partition_schema, flavor="hive")
-        else:
-            partitioning = None
+    if partition_by:
+        table_schema: pa.Schema = schema
+        if PYARROW_MAJOR_VERSION < 12:
+            partition_schema = pa.schema(
+                [
+                    pa.field(
+                        name, _large_to_normal_dtype(table_schema.field(name).type)
+                    )
+                    for name in partition_by
+                ]
+            )
+        else:
+            partition_schema = pa.schema(
+                [table_schema.field(name) for name in partition_by]
+            )
+        partitioning = ds.partitioning(partition_schema, flavor="hive")
+    else:
+        partitioning = None

         add_actions: List[AddAction] = []

From cde9887791e470a01c1a5b4b59da403905c83f22 Mon Sep 17 00:00:00 2001
From: Nikolay Ulmasov
Date: Wed, 29 Nov 2023 19:15:11 +0000
Subject: [PATCH 3/3] rebase to latest changes to write_deltalake

Signed-off-by: Nikolay Ulmasov
---
 python/deltalake/writer.py  | 37 ++++++++++++++++++++++-----------------
 python/tests/test_writer.py |  2 +-
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 70b3936028..aeabf806d2 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -245,6 +245,9 @@ def write_deltalake(
     if isinstance(partition_by, str):
         partition_by = [partition_by]

+    if isinstance(schema, Schema):
+        schema = schema.to_pyarrow()
+
     if isinstance(data, RecordBatchReader):
         data = convert_pyarrow_recordbatchreader(data, large_dtypes)
     elif isinstance(data, pa.RecordBatch):
@@ -336,24 +339,24 @@ def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
             except KeyError:
                 return dtype

-    if partition_by:
-        table_schema: pa.Schema = schema
-        if PYARROW_MAJOR_VERSION < 12:
-            partition_schema = pa.schema(
-                [
-                    pa.field(
-                        name, _large_to_normal_dtype(table_schema.field(name).type)
-                    )
-                    for name in partition_by
-                ]
-            )
+        if partition_by:
+            table_schema: pa.Schema = schema
+            if PYARROW_MAJOR_VERSION < 12:
+                partition_schema = pa.schema(
+                    [
+                        pa.field(
+                            name, _large_to_normal_dtype(table_schema.field(name).type)
+                        )
+                        for name in partition_by
+                    ]
+                )
+            else:
+                partition_schema = pa.schema(
+                    [table_schema.field(name) for name in partition_by]
+                )
+            partitioning = ds.partitioning(partition_schema, flavor="hive")
         else:
-            partition_schema = pa.schema(
-                [table_schema.field(name) for name in partition_by]
-            )
-        partitioning = ds.partitioning(partition_schema, flavor="hive")
-    else:
-        partitioning = None
+            partitioning = None

         add_actions: List[AddAction] = []
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index b0177149e8..49177782ff 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -17,7 +17,7 @@
 from pyarrow.lib import RecordBatchReader

 from deltalake import DeltaTable, Schema, write_deltalake
-from deltalake.exceptions import CommitFailedError, DeltaProtocolError
+from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError
 from deltalake.table import ProtocolVersions
 from deltalake.writer import try_get_table_and_table_uri
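
Note on usage: taken together, the three patches let write_deltalake accept either a
pyarrow.Schema or a native deltalake Schema for its schema parameter; patch 3 converts
a deltalake Schema to its PyArrow equivalent up front via Schema.to_pyarrow(). Below is
a minimal sketch of the resulting API, mirroring the new test_with_deltalake_schema
test; the table path and column data here are illustrative only, not part of the
patches:

    import pyarrow as pa

    from deltalake import DeltaTable, Schema, write_deltalake

    # Any PyArrow table works; this sample data is hypothetical.
    data = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})

    # schema= may now be a deltalake Schema as well as a pa.Schema.
    write_deltalake(
        "/tmp/example_table", data, schema=Schema.from_pyarrow(data.schema)
    )

    # The stored schema round-trips back to the original PyArrow schema.
    assert DeltaTable("/tmp/example_table").schema().to_pyarrow() == data.schema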