feat(python): Add merge mode to write_delta and remove pyarrow to delta conversions #12392

Merged: 9 commits merged on Dec 6, 2023
Changes from 8 commits
124 changes: 91 additions & 33 deletions py-polars/polars/dataframe/frame.py
@@ -3604,23 +3604,48 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
else:
raise ValueError(f"engine {engine!r} is not supported")

@overload
def write_delta(
self,
target: str | Path | deltalake.DeltaTable,
*,
mode: Literal["error", "append", "overwrite", "ignore"] = ...,
overwrite_schema: bool = ...,
storage_options: dict[str, str] | None = ...,
delta_write_options: dict[str, Any] | None = ...,
) -> None:
...

@overload
def write_delta(
self,
target: str | Path | deltalake.DeltaTable,
*,
mode: Literal["merge"],
overwrite_schema: bool = ...,
storage_options: dict[str, str] | None = ...,
delta_merge_options: dict[str, Any],
) -> deltalake.table.TableMerger:
...
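# The two overloads encode the return contract: mode="merge" yields a
# deltalake TableMerger for further chaining, all other modes return None.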

def write_delta(
self,
target: str | Path | deltalake.DeltaTable,
*,
mode: Literal["error", "append", "overwrite", "ignore"] = "error",
mode: Literal["error", "append", "overwrite", "ignore", "merge"] = "error",
overwrite_schema: bool = False,
storage_options: dict[str, str] | None = None,
delta_write_options: dict[str, Any] | None = None,
) -> None:
delta_merge_options: dict[str, Any] | None = None,
) -> deltalake.table.TableMerger | None:
"""
Write DataFrame as delta table.

Parameters
----------
target
URI of a table or a DeltaTable object.
mode : {'error', 'append', 'overwrite', 'ignore'}
mode : {'error', 'append', 'overwrite', 'ignore', 'merge'}
How to handle existing data.

* If 'error', throw an error if the table already exists (default).
@@ -3638,7 +3663,10 @@ def write_delta(
* See a list of supported storage options for Azure `here <https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html#variants>`__.
delta_write_options
Additional keyword arguments while writing a Delta lake Table.
See a list of supported write options `here <https://github.com/delta-io/delta-rs/blob/395d48b47ea638b70415899dc035cc895b220e55/python/deltalake/writer.py#L65>`__.
See a list of supported write options `here <https://delta-io.github.io/delta-rs/api/delta_writer/#deltalake.write_deltalake>`__.
delta_merge_options
Keyword arguments required to `MERGE` a Delta Lake table.
See a list of supported merge options `here <https://delta-io.github.io/delta-rs/api/delta_table/#deltalake.DeltaTable.merge>`__.

Raises
------
@@ -3647,22 +3675,15 @@ def write_delta(
ArrowInvalidError
If the DataFrame contains data types that could not be cast to their
primitive type.
TableNotFoundError
If the delta table doesn't exist and a `MERGE` action is triggered.

Notes
-----
The Polars data types :class:`Null`, :class:`Categorical` and :class:`Time`
are not supported by the delta protocol specification and will raise a
TypeError.

Some other data types are not supported but have an associated `primitive type
<https://github.com/delta-io/delta/blob/master/PROTOCOL.md#primitive-types>`__
to which they can be cast. This affects the following data types:

- Unsigned integers
- :class:`Datetime` types with millisecond or nanosecond precision or with
time zone information
- :class:`Utf8`, :class:`Binary`, and :class:`List` ('large' types)

Polars columns are always nullable. To write data to a delta table with
non-nullable columns, a custom pyarrow schema has to be passed to the
`delta_write_options`. See the last example below.
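
A minimal sketch of that pattern (the full example is collapsed in this
diff; the field name is illustrative and `table_path` is assumed to be
defined as in the earlier examples):

>>> import pyarrow as pa
>>> df.write_delta(
... table_path,
... delta_write_options={
... "schema": pa.schema([pa.field("foo", pa.int64(), nullable=False)])
... },
... ) # doctest: +SKIP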
@@ -3719,44 +3740,81 @@ def write_delta(
... },
... ) # doctest: +SKIP

Merge the DataFrame with a Delta Lake table on the local filesystem. For a
cloud object store, pass `storage_options` or a DeltaTable object instead.
For all additional TableMerger methods, see the deltalake docs `here <https://delta-io.github.io/delta-rs/api/delta_table/delta_table_merger/>`__.

Schema evolution is not yet supported in deltalake, so `overwrite_schema`
has no effect during a `MERGE`.

>>> df = pl.DataFrame(
... {
... "foo": [1, 2, 3, 4, 5],
... "bar": [6, 7, 8, 9, 10],
... "ham": ["a", "b", "c", "d", "e"],
... }
... )
>>> table_path = "/path/to/delta-table/"
>>> (
... df.write_delta(
... "table_path",
... mode="merge",
... delta_merge_options={
... "predicate": "s.foo = t.foo",
... "source_alias": "s",
... "target_alias": "t",
... },
... )
... .when_matched_update_all()
... .when_not_matched_insert_all()
... .execute()
... ) # doctest: +SKIP
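
The merger returned with `mode="merge"` is lazy: nothing is written to
the table until `execute` is called on the chained TableMerger.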
"""
from polars.io.delta import (
_check_for_unsupported_types,
_check_if_delta_available,
_convert_pa_schema_to_delta,
_resolve_delta_lake_uri,
)

_check_if_delta_available()

from deltalake.writer import write_deltalake
from deltalake import DeltaTable, write_deltalake

if delta_write_options is None:
delta_write_options = {}
_check_for_unsupported_types(self.dtypes)

if isinstance(target, (str, Path)):
target = _resolve_delta_lake_uri(str(target), strict=False)

_check_for_unsupported_types(self.dtypes)

data = self.to_arrow()

schema = delta_write_options.pop("schema", None)
if schema is None:
schema = _convert_pa_schema_to_delta(data.schema)
if mode == "merge":
if delta_merge_options is None:
raise ValueError(
"You need to pass delta_merge_options with at least a given predicate for `MERGE` to work."
)
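# Resolve the target into a DeltaTable and return the TableMerger
# unexecuted; the caller chains when_matched_* / when_not_matched_*
# clauses and finishes with .execute().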
if isinstance(target, str):
dt = DeltaTable(table_uri=target, storage_options=storage_options)
else:
dt = target

data = data.cast(schema)
return dt.merge(data, **delta_merge_options)

write_deltalake(
table_or_uri=target,
data=data,
schema=schema,
mode=mode,
overwrite_schema=overwrite_schema,
storage_options=storage_options,
large_dtypes=True,
**delta_write_options,
)
else:
if delta_write_options is None:
delta_write_options = {}

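# A pyarrow schema passed via delta_write_options takes precedence;
# otherwise write_deltalake derives the schema from the Arrow data.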
schema = delta_write_options.pop("schema", None)
write_deltalake(
table_or_uri=target,
data=data,
schema=schema,
mode=mode,
overwrite_schema=overwrite_schema,
storage_options=storage_options,
large_dtypes=True,
**delta_write_options,
)
return None

def estimated_size(self, unit: SizeUnit = "b") -> int | float:
"""
41 changes: 1 addition & 40 deletions py-polars/polars/io/delta.py
@@ -8,7 +8,6 @@
from polars.datatypes import Categorical, Null, Time
from polars.datatypes.convert import unpack_dtypes
from polars.dependencies import _DELTALAKE_AVAILABLE, deltalake
from polars.dependencies import pyarrow as pa
from polars.io.pyarrow_dataset import scan_pyarrow_dataset

if TYPE_CHECKING:
@@ -325,42 +324,4 @@ def _check_for_unsupported_types(dtypes: list[DataType]) -> None:
overlap = schema_dtypes & unsupported_types

if overlap:
raise TypeError(f"DataFrame contains unsupported data types: {overlap!r}")


def _convert_pa_schema_to_delta(schema: pa.schema) -> pa.schema:
"""Convert a PyArrow schema to a schema compatible with Delta Lake."""
# TODO: Add time zone support
dtype_map = {
pa.uint8(): pa.int8(),
pa.uint16(): pa.int16(),
pa.uint32(): pa.int32(),
pa.uint64(): pa.int64(),
}

def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType:
# Handle nested types
if isinstance(dtype, pa.LargeListType):
return list_to_delta_dtype(dtype)
elif isinstance(dtype, pa.StructType):
return struct_to_delta_dtype(dtype)
elif isinstance(dtype, pa.TimestampType):
# TODO: Support time zones when implemented by delta-rs. See:
# https://github.com/delta-io/delta-rs/issues/1598
return pa.timestamp("us")
try:
return dtype_map[dtype]
except KeyError:
return dtype

def list_to_delta_dtype(dtype: pa.LargeListType) -> pa.LargeListType:
nested_dtype = dtype.value_type
nested_dtype_cast = dtype_to_delta_dtype(nested_dtype)
return pa.large_list(nested_dtype_cast)

def struct_to_delta_dtype(dtype: pa.StructType) -> pa.StructType:
fields = [dtype.field(i) for i in range(dtype.num_fields)]
fields_cast = [pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in fields]
return pa.struct(fields_cast)

return pa.schema([pa.field(f.name, dtype_to_delta_dtype(f.type)) for f in schema])
raise TypeError(f"dataframe contains unsupported data types: {overlap!r}")
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
@@ -42,7 +42,7 @@ Changelog = "https://github.com/pola-rs/polars/releases"
adbc = ["adbc_driver_sqlite"]
cloudpickle = ["cloudpickle"]
connectorx = ["connectorx >= 0.3.2"]
deltalake = ["deltalake >= 0.13.0"]
deltalake = ["deltalake >= 0.14.0"]
fsspec = ["fsspec"]
gevent = ["gevent"]
matplotlib = ["matplotlib"]
2 changes: 1 addition & 1 deletion py-polars/requirements-dev.txt
@@ -41,7 +41,7 @@ openpyxl
pyxlsb
xlsx2csv
XlsxWriter
deltalake>=0.13.0
deltalake>=0.14.0
# Dataframe interchange protocol
dataframe-api-compat >= 0.1.6
pyiceberg >= 0.5.0
38 changes: 38 additions & 0 deletions py-polars/tests/unit/io/test_delta.py
@@ -7,6 +7,8 @@
import pyarrow.fs
import pytest
from deltalake import DeltaTable
from deltalake.exceptions import TableNotFoundError
from deltalake.table import TableMerger

import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal
@@ -369,3 +371,39 @@ def test_write_delta_with_tz_in_df(expr: pl.Expr, tmp_path: Path) -> None:

expected = df.cast(pl.Datetime)
assert_frame_equal(result, expected)


def test_write_delta_with_merge_and_no_table(tmp_path: Path) -> None:
df = pl.DataFrame({"a": [1, 2, 3]})

with pytest.raises(TableNotFoundError):
df.write_delta(
tmp_path, mode="merge", delta_merge_options={"predicate": "a = a"}
)


def test_write_delta_with_merge(tmp_path: Path) -> None:
df = pl.DataFrame({"a": [1, 2, 3]})

df.write_delta(tmp_path, mode="append")

merger = df.write_delta(
tmp_path,
mode="merge",
delta_merge_options={
"predicate": "s.a = t.a",
"source_alias": "s",
"target_alias": "t",
},
)

assert isinstance(merger, TableMerger)
assert merger.predicate == "s.a = t.a"
assert merger.source_alias == "s"
assert merger.target_alias == "t"

merger.when_matched_delete(predicate="t.a > 2").execute()

table = pl.read_delta(str(tmp_path))

assert_frame_equal(df.filter(pl.col("a") <= 2), table)