Skip to content

Commit

Permalink
refactor: python merge
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Aug 24, 2024
1 parent 46b38d2 commit f776d71
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 320 deletions.
35 changes: 23 additions & 12 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class RawDeltaTable:
custom_metadata: Optional[Dict[str, str]],
post_commithook_properties: Optional[Dict[str, Optional[bool]]],
) -> str: ...
def merge_execute(
def create_merge_builder(
self,
source: pyarrow.RecordBatchReader,
predicate: str,
Expand All @@ -153,17 +153,8 @@ class RawDeltaTable:
custom_metadata: Optional[Dict[str, str]],
post_commithook_properties: Optional[Dict[str, Optional[bool]]],
safe_cast: bool,
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
) -> PyMergeBuilder: ...
def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ...
def get_active_partitions(
self, partitions_filters: Optional[FilterType] = None
) -> Any: ...
Expand Down Expand Up @@ -244,6 +235,26 @@ def get_num_idx_cols_and_stats_columns(
table: Optional[RawDeltaTable], configuration: Optional[Mapping[str, Optional[str]]]
) -> Tuple[int, Optional[List[str]]]: ...

class PyMergeBuilder:
source_alias: str
target_alias: str
arrow_schema: pyarrow.Schema

def when_matched_update(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_matched_delete(self, predicate: Optional[str]) -> None: ...
def when_not_matched_insert(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_not_matched_by_source_update(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_not_matched_by_source_delete(
self,
predicate: Optional[str],
) -> None: ...

# Can't implement inheritance (see note in src/schema.rs), so this is next
# best thing.
DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]
Expand Down
166 changes: 32 additions & 134 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import os

from deltalake._internal import (
PyMergeBuilder,
RawDeltaTable,
)
from deltalake._internal import create_deltalake as _create_deltalake
Expand Down Expand Up @@ -952,17 +953,19 @@ def merge(
source.schema, (batch for batch in source)
)

return TableMerger(
self,
py_merge_builder = self._table.create_merge_builder(
source=source,
predicate=predicate,
source_alias=source_alias,
target_alias=target_alias,
safe_cast=not error_on_type_mismatch,
writer_properties=writer_properties,
custom_metadata=custom_metadata,
post_commithook_properties=post_commithook_properties,
post_commithook_properties=post_commithook_properties.__dict__
if post_commithook_properties
else None,
)
return TableMerger(py_merge_builder, self._table)

def restore(
self,
Expand Down Expand Up @@ -1295,37 +1298,11 @@ class TableMerger:

def __init__(
self,
table: DeltaTable,
source: pyarrow.RecordBatchReader,
predicate: str,
source_alias: Optional[str] = None,
target_alias: Optional[str] = None,
safe_cast: bool = True,
writer_properties: Optional[WriterProperties] = None,
custom_metadata: Optional[Dict[str, str]] = None,
post_commithook_properties: Optional[PostCommitHookProperties] = None,
builder: PyMergeBuilder,
table: RawDeltaTable,
):
self.table = table
self.source = source
self.predicate = predicate
self.source_alias = source_alias
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties = writer_properties
self.custom_metadata = custom_metadata
self.post_commithook_properties = post_commithook_properties
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = (
None
)
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None
self._builder = builder
self._table = table

def when_matched_update(
self, updates: Dict[str, str], predicate: Optional[str] = None
Expand Down Expand Up @@ -1372,14 +1349,7 @@ def when_matched_update(
2 3 6
```
"""
if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
self._builder.when_matched_update(updates, predicate)
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -1424,24 +1394,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
2 3 6
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else ""
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
for col in self._builder.arrow_schema
}

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

self._builder.when_matched_update(updates, predicate)
return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -1507,19 +1473,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
0 1 4
```
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
self._builder.when_matched_delete(predicate)
return self

def when_not_matched_insert(
Expand Down Expand Up @@ -1572,16 +1526,7 @@ def when_not_matched_insert(
3 4 7
```
"""

if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

self._builder.when_not_matched_insert(updates, predicate)
return self

def when_not_matched_insert_all(
Expand Down Expand Up @@ -1630,22 +1575,19 @@ def when_not_matched_insert_all(
3 4 7
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else ""
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)
updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
for col in self._builder.arrow_schema
}
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

self._builder.when_not_matched_insert(updates, predicate)
return self

def when_not_matched_by_source_update(
Expand Down Expand Up @@ -1695,15 +1637,7 @@ def when_not_matched_by_source_update(
2 3 6
```
"""

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
self._builder.when_not_matched_by_source_update(updates, predicate)
return self

def when_not_matched_by_source_delete(
Expand All @@ -1722,19 +1656,7 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
self._builder.when_not_matched_by_source_delete(predicate)
return self

def execute(self) -> Dict[str, Any]:
Expand All @@ -1743,31 +1665,7 @@ def execute(self) -> Dict[str, Any]:
Returns:
Dict: metrics
"""
metrics = self.table._table.merge_execute(
source=self.source,
predicate=self.predicate,
source_alias=self.source_alias,
target_alias=self.target_alias,
safe_cast=self.safe_cast,
writer_properties=self.writer_properties
if self.writer_properties
else None,
custom_metadata=self.custom_metadata,
post_commithook_properties=self.post_commithook_properties.__dict__
if self.post_commithook_properties
else None,
matched_update_updates=self.matched_update_updates,
matched_update_predicate=self.matched_update_predicate,
matched_delete_predicate=self.matched_delete_predicate,
matched_delete_all=self.matched_delete_all,
not_matched_insert_updates=self.not_matched_insert_updates,
not_matched_insert_predicate=self.not_matched_insert_predicate,
not_matched_by_source_update_updates=self.not_matched_by_source_update_updates,
not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate,
not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate,
not_matched_by_source_delete_all=self.not_matched_by_source_delete_all,
)
self.table.update_incremental()
metrics = self._table.merge_execute(self._builder)
return json.loads(metrics)


Expand Down
Loading

0 comments on commit f776d71

Please sign in to comment.