Skip to content

Commit

Permalink
feat(python): allow for multiple when calls in MERGE operation (#1750)
Browse files Browse the repository at this point in the history
# Description
You can now also do multiple when clauses just like in Rust and PySpark.
I added one test for now 😄, will add more later when I have some time.

I'll update the docs in another PR to reflect the possibility of this
behavior.

# Related Issue(s)
<!---
For example:

- closes #106
--->
- closes #1736

---------

Co-authored-by: Will Jones <[email protected]>
  • Loading branch information
ion-elgreco and wjones127 authored Nov 4, 2023
1 parent 5a5dbcd commit 6e9894f
Show file tree
Hide file tree
Showing 4 changed files with 513 additions and 160 deletions.
16 changes: 8 additions & 8 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ class RawDeltaTable:
target_alias: Optional[str],
writer_properties: Optional[Dict[str, int | None]],
safe_cast: bool,
matched_update_updates: Optional[Dict[str, str]],
matched_update_predicate: Optional[str],
matched_delete_predicate: Optional[str],
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[Dict[str, str]],
not_matched_insert_predicate: Optional[str],
not_matched_by_source_update_updates: Optional[Dict[str, str]],
not_matched_by_source_update_predicate: Optional[str],
not_matched_by_source_delete_predicate: Optional[str],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
def get_active_partitions(
Expand Down
276 changes: 189 additions & 87 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def repair(self, dry_run: bool = False) -> Dict[str, Any]:


class TableMerger:
"""API for various table MERGE commands."""
"""API for various table `MERGE` commands."""

def __init__(
self,
Expand All @@ -922,15 +922,17 @@ def __init__(
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties: Optional[Dict[str, Optional[int]]] = None
self.matched_update_updates: Optional[Dict[str, str]] = None
self.matched_update_predicate: Optional[str] = None
self.matched_delete_predicate: Optional[str] = None
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[Dict[str, str]] = None
self.not_matched_insert_predicate: Optional[str] = None
self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None
self.not_matched_by_source_update_predicate: Optional[str] = None
self.not_matched_by_source_delete_predicate: Optional[str] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[
List[Optional[str]]
] = None
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None

def with_writer_properties(
Expand Down Expand Up @@ -975,23 +977,32 @@ def when_matched_update(
Returns:
TableMerger: TableMerger Object
Examples:
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_update(
... updates = {
... "x": "source.x",
... "y": "source.y"
... }
... ).execute()
Examples:
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate="target.x = source.x", \
source_alias="source", \
target_alias="target") \
.when_matched_update(updates={"x": "source.x", "y": "source.y"}) \
.execute() \
)
```
"""
self.matched_update_updates = updates
self.matched_update_predicate = predicate
if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -1006,22 +1017,40 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
Examples:
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_update_all().execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
(\
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_update_all() \
.execute() \
)
```
"""

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""

self.matched_update_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.matched_update_predicate = predicate

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand All @@ -1037,30 +1066,50 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Examples:
Delete on a predicate
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_delete(predicate = "source.deleted = true")
... .execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_delete( \
predicate = "source.deleted = true") \
.execute() \
```
Delete all records that were matched
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_matched_delete()
... .execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_matched_delete() \
.execute() \
```
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
self.matched_delete_predicate = predicate
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
return self

def when_not_matched_insert(
Expand All @@ -1078,21 +1127,35 @@ def when_not_matched_insert(
Examples:
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_insert(
... updates = {
... "x": "source.x",
... "y": "source.y"
... }
... ).execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_insert( \
updates = { \
"x": "source.x", \
"y": "source.y", \
}) \
.execute() \
)
```
"""

self.not_matched_insert_updates = updates
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

Expand All @@ -1111,21 +1174,39 @@ def when_not_matched_insert_all(
Examples:
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_insert_all().execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt \
.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_insert_all() \
.execute() \
)
```
"""

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
self.not_matched_insert_updates = {
updates = {
f"{trgt_alias}{col.name}": f"{src_alias}{col.name}"
for col in self.source.schema
}
self.not_matched_insert_predicate = predicate
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

return self

def when_not_matched_by_source_update(
Expand All @@ -1140,21 +1221,34 @@ def when_not_matched_by_source_update(
Returns:
TableMerger: TableMerger Object
>>> from deltalake import DeltaTable
>>> import pyarrow as pa
>>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> dt = DeltaTable("tmp")
>>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \
... .when_not_matched_by_source_update(
... predicate = "y > 3"
... updates = {
... "y": "0",
... }
... ).execute()
```
from deltalake import DeltaTable
import pyarrow as pa
data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]})
dt = DeltaTable("tmp")
( \
dt.merge( \
source=data, \
predicate='target.x = source.x', \
source_alias='source', \
target_alias='target') \
.when_not_matched_by_source_update( \
predicate = "y > 3", \
updates = {"y": "0"}) \
.execute() \
) \
```
"""
self.not_matched_by_source_update_updates = updates
self.not_matched_by_source_update_predicate = predicate

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
return self

def when_not_matched_by_source_delete(
Expand All @@ -1169,15 +1263,23 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
self.not_matched_by_source_delete_predicate = predicate
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
return self

def execute(self) -> Dict[str, Any]:
"""Executes MERGE with the previously provided settings in Rust with Apache Datafusion query engine.
"""Executes `MERGE` with the previously provided settings in Rust with Apache Datafusion query engine.
Returns:
Dict[str, Any]: metrics
Expand Down
Loading

0 comments on commit 6e9894f

Please sign in to comment.