Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python): add pymergebuilder #2823

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class RawDeltaTable:
custom_metadata: Optional[Dict[str, str]],
post_commithook_properties: Optional[Dict[str, Optional[bool]]],
) -> str: ...
def merge_execute(
def create_merge_builder(
self,
source: pyarrow.RecordBatchReader,
predicate: str,
Expand All @@ -153,17 +153,8 @@ class RawDeltaTable:
custom_metadata: Optional[Dict[str, str]],
post_commithook_properties: Optional[Dict[str, Optional[bool]]],
safe_cast: bool,
matched_update_updates: Optional[List[Dict[str, str]]],
matched_update_predicate: Optional[List[Optional[str]]],
matched_delete_predicate: Optional[List[str]],
matched_delete_all: Optional[bool],
not_matched_insert_updates: Optional[List[Dict[str, str]]],
not_matched_insert_predicate: Optional[List[Optional[str]]],
not_matched_by_source_update_updates: Optional[List[Dict[str, str]]],
not_matched_by_source_update_predicate: Optional[List[Optional[str]]],
not_matched_by_source_delete_predicate: Optional[List[str]],
not_matched_by_source_delete_all: Optional[bool],
) -> str: ...
) -> PyMergeBuilder: ...
def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ...
def get_active_partitions(
self, partitions_filters: Optional[FilterType] = None
) -> Any: ...
Expand Down Expand Up @@ -244,6 +235,26 @@ def get_num_idx_cols_and_stats_columns(
table: Optional[RawDeltaTable], configuration: Optional[Mapping[str, Optional[str]]]
) -> Tuple[int, Optional[List[str]]]: ...

class PyMergeBuilder:
source_alias: str
target_alias: str
arrow_schema: pyarrow.Schema

def when_matched_update(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_matched_delete(self, predicate: Optional[str]) -> None: ...
def when_not_matched_insert(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_not_matched_by_source_update(
self, updates: Dict[str, str], predicate: Optional[str]
) -> None: ...
def when_not_matched_by_source_delete(
self,
predicate: Optional[str],
) -> None: ...

# Can't implement inheritance (see note in src/schema.rs), so this is next
# best thing.
DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"]
Expand Down
166 changes: 32 additions & 134 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import os

from deltalake._internal import (
PyMergeBuilder,
RawDeltaTable,
)
from deltalake._internal import create_deltalake as _create_deltalake
Expand Down Expand Up @@ -952,17 +953,19 @@ def merge(
source.schema, (batch for batch in source)
)

return TableMerger(
self,
py_merge_builder = self._table.create_merge_builder(
source=source,
predicate=predicate,
source_alias=source_alias,
target_alias=target_alias,
safe_cast=not error_on_type_mismatch,
writer_properties=writer_properties,
custom_metadata=custom_metadata,
post_commithook_properties=post_commithook_properties,
post_commithook_properties=post_commithook_properties.__dict__
if post_commithook_properties
else None,
)
return TableMerger(py_merge_builder, self._table)

def restore(
self,
Expand Down Expand Up @@ -1295,37 +1298,11 @@ class TableMerger:

def __init__(
self,
table: DeltaTable,
source: pyarrow.RecordBatchReader,
predicate: str,
source_alias: Optional[str] = None,
target_alias: Optional[str] = None,
safe_cast: bool = True,
writer_properties: Optional[WriterProperties] = None,
custom_metadata: Optional[Dict[str, str]] = None,
post_commithook_properties: Optional[PostCommitHookProperties] = None,
builder: PyMergeBuilder,
table: RawDeltaTable,
):
self.table = table
self.source = source
self.predicate = predicate
self.source_alias = source_alias
self.target_alias = target_alias
self.safe_cast = safe_cast
self.writer_properties = writer_properties
self.custom_metadata = custom_metadata
self.post_commithook_properties = post_commithook_properties
self.matched_update_updates: Optional[List[Dict[str, str]]] = None
self.matched_update_predicate: Optional[List[Optional[str]]] = None
self.matched_delete_predicate: Optional[List[str]] = None
self.matched_delete_all: Optional[bool] = None
self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None
self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None
self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = (
None
)
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None
self.not_matched_by_source_delete_all: Optional[bool] = None
self._builder = builder
self._table = table

def when_matched_update(
self, updates: Dict[str, str], predicate: Optional[str] = None
Expand Down Expand Up @@ -1372,14 +1349,7 @@ def when_matched_update(
2 3 6
```
"""
if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]
self._builder.when_matched_update(updates, predicate)
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -1424,24 +1394,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
2 3 6
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else ""
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
for col in self._builder.arrow_schema
}

if isinstance(self.matched_update_updates, list) and isinstance(
self.matched_update_predicate, list
):
self.matched_update_updates.append(updates)
self.matched_update_predicate.append(predicate)
else:
self.matched_update_updates = [updates]
self.matched_update_predicate = [predicate]

self._builder.when_matched_update(updates, predicate)
return self

def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
Expand Down Expand Up @@ -1507,19 +1473,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger":
0 1 4
```
"""
if self.matched_delete_all is not None:
raise ValueError(
"""when_matched_delete without a predicate has already been set, which means
it will delete all, any subsequent when_matched_delete, won't make sense."""
)

if predicate is None:
self.matched_delete_all = True
else:
if isinstance(self.matched_delete_predicate, list):
self.matched_delete_predicate.append(predicate)
else:
self.matched_delete_predicate = [predicate]
self._builder.when_matched_delete(predicate)
return self

def when_not_matched_insert(
Expand Down Expand Up @@ -1572,16 +1526,7 @@ def when_not_matched_insert(
3 4 7
```
"""

if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

self._builder.when_not_matched_insert(updates, predicate)
return self

def when_not_matched_insert_all(
Expand Down Expand Up @@ -1630,22 +1575,19 @@ def when_not_matched_insert_all(
3 4 7
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias

src_alias = (self.source_alias + ".") if self.source_alias is not None else ""
trgt_alias = (self.target_alias + ".") if self.target_alias is not None else ""
src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else ""
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)
updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self.source.schema
for col in self._builder.arrow_schema
}
if isinstance(self.not_matched_insert_updates, list) and isinstance(
self.not_matched_insert_predicate, list
):
self.not_matched_insert_updates.append(updates)
self.not_matched_insert_predicate.append(predicate)
else:
self.not_matched_insert_updates = [updates]
self.not_matched_insert_predicate = [predicate]

self._builder.when_not_matched_insert(updates, predicate)
return self

def when_not_matched_by_source_update(
Expand Down Expand Up @@ -1695,15 +1637,7 @@ def when_not_matched_by_source_update(
2 3 6
```
"""

if isinstance(self.not_matched_by_source_update_updates, list) and isinstance(
self.not_matched_by_source_update_predicate, list
):
self.not_matched_by_source_update_updates.append(updates)
self.not_matched_by_source_update_predicate.append(predicate)
else:
self.not_matched_by_source_update_updates = [updates]
self.not_matched_by_source_update_predicate = [predicate]
self._builder.when_not_matched_by_source_update(updates, predicate)
return self

def when_not_matched_by_source_delete(
Expand All @@ -1722,19 +1656,7 @@ def when_not_matched_by_source_delete(
Returns:
TableMerger: TableMerger Object
"""
if self.not_matched_by_source_delete_all is not None:
raise ValueError(
"""when_not_matched_by_source_delete without a predicate has already been set, which means
it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense."""
)

if predicate is None:
self.not_matched_by_source_delete_all = True
else:
if isinstance(self.not_matched_by_source_delete_predicate, list):
self.not_matched_by_source_delete_predicate.append(predicate)
else:
self.not_matched_by_source_delete_predicate = [predicate]
self._builder.when_not_matched_by_source_delete(predicate)
return self

def execute(self) -> Dict[str, Any]:
Expand All @@ -1743,31 +1665,7 @@ def execute(self) -> Dict[str, Any]:
Returns:
Dict: metrics
"""
metrics = self.table._table.merge_execute(
source=self.source,
predicate=self.predicate,
source_alias=self.source_alias,
target_alias=self.target_alias,
safe_cast=self.safe_cast,
writer_properties=self.writer_properties
if self.writer_properties
else None,
custom_metadata=self.custom_metadata,
post_commithook_properties=self.post_commithook_properties.__dict__
if self.post_commithook_properties
else None,
matched_update_updates=self.matched_update_updates,
matched_update_predicate=self.matched_update_predicate,
matched_delete_predicate=self.matched_delete_predicate,
matched_delete_all=self.matched_delete_all,
not_matched_insert_updates=self.not_matched_insert_updates,
not_matched_insert_predicate=self.not_matched_insert_predicate,
not_matched_by_source_update_updates=self.not_matched_by_source_update_updates,
not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate,
not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate,
not_matched_by_source_delete_all=self.not_matched_by_source_delete_all,
)
self.table.update_incremental()
metrics = self._table.merge_execute(self._builder)
return json.loads(metrics)


Expand Down
Loading
Loading