From 6e9894f974f8807784876e74a4b12d02e002b0d5 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 4 Nov 2023 21:30:49 +0100 Subject: [PATCH] feat(python): allow for multiple `when` calls in MERGE operation (#1750) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description You can now also do multiple when clauses just like in Rust and PySpark. I added one test for now 😄, will add more later when I have some time. I'll update the docs in another PR to reflect the possibility of this behavior. # Related Issue(s) - closes https://github.com/delta-io/delta-rs/issues/1736 --------- Co-authored-by: Will Jones --- python/deltalake/_internal.pyi | 16 +- python/deltalake/table.py | 276 ++++++++++++++++++++++----------- python/src/lib.rs | 150 ++++++++++-------- python/tests/test_merge.py | 231 +++++++++++++++++++++++++++ 4 files changed, 513 insertions(+), 160 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 16d07e144e..85887aeff5 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -104,15 +104,15 @@ class RawDeltaTable: target_alias: Optional[str], writer_properties: Optional[Dict[str, int | None]], safe_cast: bool, - matched_update_updates: Optional[Dict[str, str]], - matched_update_predicate: Optional[str], - matched_delete_predicate: Optional[str], + matched_update_updates: Optional[List[Dict[str, str]]], + matched_update_predicate: Optional[List[Optional[str]]], + matched_delete_predicate: Optional[List[str]], matched_delete_all: Optional[bool], - not_matched_insert_updates: Optional[Dict[str, str]], - not_matched_insert_predicate: Optional[str], - not_matched_by_source_update_updates: Optional[Dict[str, str]], - not_matched_by_source_update_predicate: Optional[str], - not_matched_by_source_delete_predicate: Optional[str], + not_matched_insert_updates: Optional[List[Dict[str, str]]], + 
not_matched_insert_predicate: Optional[List[Optional[str]]], + not_matched_by_source_update_updates: Optional[List[Dict[str, str]]], + not_matched_by_source_update_predicate: Optional[List[Optional[str]]], + not_matched_by_source_delete_predicate: Optional[List[str]], not_matched_by_source_delete_all: Optional[bool], ) -> str: ... def get_active_partitions( diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 61dce5ee0f..ad82a010fd 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -904,7 +904,7 @@ def repair(self, dry_run: bool = False) -> Dict[str, Any]: class TableMerger: - """API for various table MERGE commands.""" + """API for various table `MERGE` commands.""" def __init__( self, @@ -922,15 +922,17 @@ def __init__( self.target_alias = target_alias self.safe_cast = safe_cast self.writer_properties: Optional[Dict[str, Optional[int]]] = None - self.matched_update_updates: Optional[Dict[str, str]] = None - self.matched_update_predicate: Optional[str] = None - self.matched_delete_predicate: Optional[str] = None + self.matched_update_updates: Optional[List[Dict[str, str]]] = None + self.matched_update_predicate: Optional[List[Optional[str]]] = None + self.matched_delete_predicate: Optional[List[str]] = None self.matched_delete_all: Optional[bool] = None - self.not_matched_insert_updates: Optional[Dict[str, str]] = None - self.not_matched_insert_predicate: Optional[str] = None - self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None - self.not_matched_by_source_update_predicate: Optional[str] = None - self.not_matched_by_source_delete_predicate: Optional[str] = None + self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None + self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None + self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None + self.not_matched_by_source_update_predicate: Optional[ + List[Optional[str]] + ] = None + 
self.not_matched_by_source_delete_predicate: Optional[List[str]] = None self.not_matched_by_source_delete_all: Optional[bool] = None def with_writer_properties( @@ -975,23 +977,32 @@ def when_matched_update( Returns: TableMerger: TableMerger Object - - Examples: - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_matched_update( - ... updates = { - ... "x": "source.x", - ... "y": "source.y" - ... } - ... ).execute() + Examples: + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt.merge( \ + source=data, \ + predicate="target.x = source.x", \ + source_alias="source", \ + target_alias="target") \ + .when_matched_update(updates={"x": "source.x", "y": "source.y"}) \ + .execute() \ + ) + ``` """ - self.matched_update_updates = updates - self.matched_update_predicate = predicate + if isinstance(self.matched_update_updates, list) and isinstance( + self.matched_update_predicate, list + ): + self.matched_update_updates.append(updates) + self.matched_update_predicate.append(predicate) + else: + self.matched_update_updates = [updates] + self.matched_update_predicate = [predicate] return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1006,22 +1017,40 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg Examples: - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... 
.when_matched_update_all().execute() + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + (\ + dt.merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_matched_update_all() \ + .execute() \ + ) + ``` """ src_alias = (self.source_alias + ".") if self.source_alias is not None else "" trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" - self.matched_update_updates = { + updates = { f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" for col in self.source.schema } - self.matched_update_predicate = predicate + + if isinstance(self.matched_update_updates, list) and isinstance( + self.matched_update_predicate, list + ): + self.matched_update_updates.append(updates) + self.matched_update_predicate.append(predicate) + else: + self.matched_update_updates = [updates] + self.matched_update_predicate = [predicate] + return self def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1037,30 +1066,50 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": Examples: Delete on a predicate - - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_matched_delete(predicate = "source.deleted = true") - ... 
.execute() - + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt.merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_matched_delete( \ + predicate = "source.deleted = true") \ + .execute() \ + ``` Delete all records that were matched - - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_matched_delete() - ... .execute() + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt.merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_matched_delete() \ + .execute() \ + ``` """ + if self.matched_delete_all is not None: + raise ValueError( + """when_matched_delete without a predicate has already been set, which means + it will delete all, any subsequent when_matched_delete, won't make sense.""" + ) if predicate is None: self.matched_delete_all = True else: - self.matched_delete_predicate = predicate + if isinstance(self.matched_delete_predicate, list): + self.matched_delete_predicate.append(predicate) + else: + self.matched_delete_predicate = [predicate] return self def when_not_matched_insert( @@ -1078,21 +1127,35 @@ def when_not_matched_insert( Examples: - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_not_matched_insert( - ... updates = { - ... "x": "source.x", - ... "y": "source.y" - ... } - ... 
).execute() + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt.merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_not_matched_insert( \ + updates = { \ + "x": "source.x", \ + "y": "source.y", \ + }) \ + .execute() \ + ) + ``` """ - self.not_matched_insert_updates = updates - self.not_matched_insert_predicate = predicate + if isinstance(self.not_matched_insert_updates, list) and isinstance( + self.not_matched_insert_predicate, list + ): + self.not_matched_insert_updates.append(updates) + self.not_matched_insert_predicate.append(predicate) + else: + self.not_matched_insert_updates = [updates] + self.not_matched_insert_predicate = [predicate] return self @@ -1111,21 +1174,39 @@ def when_not_matched_insert_all( Examples: - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... 
.when_not_matched_insert_all().execute() + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt \ + .merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_not_matched_insert_all() \ + .execute() \ + ) + ``` """ src_alias = (self.source_alias + ".") if self.source_alias is not None else "" trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" - self.not_matched_insert_updates = { + updates = { f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" for col in self.source.schema } - self.not_matched_insert_predicate = predicate + if isinstance(self.not_matched_insert_updates, list) and isinstance( + self.not_matched_insert_predicate, list + ): + self.not_matched_insert_updates.append(updates) + self.not_matched_insert_predicate.append(predicate) + else: + self.not_matched_insert_updates = [updates] + self.not_matched_insert_predicate = [predicate] + return self def when_not_matched_by_source_update( @@ -1140,21 +1221,34 @@ def when_not_matched_by_source_update( Returns: TableMerger: TableMerger Object - - >>> from deltalake import DeltaTable - >>> import pyarrow as pa - >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) - >>> dt = DeltaTable("tmp") - >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ - ... .when_not_matched_by_source_update( - ... predicate = "y > 3" - ... updates = { - ... "y": "0", - ... } - ... 
).execute() + + ``` + from deltalake import DeltaTable + import pyarrow as pa + data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + dt = DeltaTable("tmp") + ( \ + dt.merge( \ + source=data, \ + predicate='target.x = source.x', \ + source_alias='source', \ + target_alias='target') \ + .when_not_matched_by_source_update( \ + predicate = "y > 3", \ + updates = {"y": "0"}) \ + .execute() \ + ) \ + ``` """ - self.not_matched_by_source_update_updates = updates - self.not_matched_by_source_update_predicate = predicate + + if isinstance(self.not_matched_by_source_update_updates, list) and isinstance( + self.not_matched_by_source_update_predicate, list + ): + self.not_matched_by_source_update_updates.append(updates) + self.not_matched_by_source_update_predicate.append(predicate) + else: + self.not_matched_by_source_update_updates = [updates] + self.not_matched_by_source_update_predicate = [predicate] return self def when_not_matched_by_source_delete( @@ -1169,15 +1263,23 @@ def when_not_matched_by_source_delete( Returns: TableMerger: TableMerger Object """ + if self.not_matched_by_source_delete_all is not None: + raise ValueError( + """when_not_matched_by_source_delete without a predicate has already been set, which means + it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense.""" + ) if predicate is None: self.not_matched_by_source_delete_all = True else: - self.not_matched_by_source_delete_predicate = predicate + if isinstance(self.not_matched_by_source_delete_predicate, list): + self.not_matched_by_source_delete_predicate.append(predicate) + else: + self.not_matched_by_source_delete_predicate = [predicate] return self def execute(self) -> Dict[str, Any]: - """Executes MERGE with the previously provided settings in Rust with Apache Datafusion query engine. + """Executes `MERGE` with the previously provided settings in Rust with Apache Datafusion query engine. 
Returns: Dict[str, Any]: metrics diff --git a/python/src/lib.rs b/python/src/lib.rs index 5f2fd8a11f..93f71597ba 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -427,15 +427,15 @@ impl RawDeltaTable { target_alias: Option, safe_cast: bool, writer_properties: Option>, - matched_update_updates: Option>, - matched_update_predicate: Option, - matched_delete_predicate: Option, + matched_update_updates: Option>>, + matched_update_predicate: Option>>, + matched_delete_predicate: Option>, matched_delete_all: Option, - not_matched_insert_updates: Option>, - not_matched_insert_predicate: Option, - not_matched_by_source_update_updates: Option>, - not_matched_by_source_update_predicate: Option, - not_matched_by_source_delete_predicate: Option, + not_matched_insert_updates: Option>>, + not_matched_insert_predicate: Option>>, + not_matched_by_source_update_updates: Option>>, + not_matched_by_source_update_predicate: Option>>, + not_matched_by_source_delete_predicate: Option>, not_matched_by_source_delete_all: Option, ) -> PyResult { let ctx = SessionContext::new(); @@ -489,23 +489,29 @@ impl RawDeltaTable { if let Some(mu_updates) = matched_update_updates { if let Some(mu_predicate) = matched_update_predicate { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in mu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update.predicate(mu_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in mu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; + for it in mu_updates.iter().zip(mu_predicate.iter()) { + let (update_values, predicate_value) = it; + + if let Some(pred) = predicate_value { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + 
update.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } } } @@ -514,52 +520,64 @@ impl RawDeltaTable { .when_matched_delete(|delete| delete) .map_err(PythonError::from)?; } else if let Some(md_predicate) = matched_delete_predicate { - cmd = cmd - .when_matched_delete(|delete| delete.predicate(md_predicate)) - .map_err(PythonError::from)?; + for pred in md_predicate.iter() { + cmd = cmd + .when_matched_delete(|delete| delete.predicate(pred.clone())) + .map_err(PythonError::from)?; + } } if let Some(nmi_updates) = not_matched_insert_updates { if let Some(nmi_predicate) = not_matched_insert_predicate { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in nmi_updates { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert.predicate(nmi_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in nmi_updates { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert - }) - .map_err(PythonError::from)?; + for it in nmi_updates.iter().zip(nmi_predicate.iter()) { + let (update_values, predicate_value) = it; + if let Some(pred) = predicate_value { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in update_values { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in update_values { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert + }) + .map_err(PythonError::from)?; + } + } } } if let Some(nmbsu_updates) = not_matched_by_source_update_updates { if let 
Some(nmbsu_predicate) = not_matched_by_source_update_predicate { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in nmbsu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update.predicate(nmbsu_predicate) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in nmbsu_updates { - update = update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; + for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) { + let (update_values, predicate_value) = it; + if let Some(pred) = predicate_value { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(pred.clone()) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in update_values { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } } } @@ -568,9 +586,11 @@ impl RawDeltaTable { .when_not_matched_by_source_delete(|delete| delete) .map_err(PythonError::from)?; } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { - cmd = cmd - .when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate)) - .map_err(PythonError::from)?; + for pred in nmbs_predicate.iter() { + cmd = cmd + .when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone())) + .map_err(PythonError::from)?; + } } let (table, metrics) = rt()? 
diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index fc08563443..ddc5a34ea1 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -490,3 +490,234 @@ def test_merge_when_not_matched_by_source_delete_wo_predicate( assert last_action["operation"] == "MERGE" assert result == expected + + +def test_merge_multiple_when_matched_update_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, True]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_matched_update( + updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = False", + ).when_matched_update( + updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = True", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_multiple_when_matched_update_all_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, True]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + 
predicate="target.id = source.id", + ).when_matched_update_all( + predicate="source.deleted = False", + ).when_matched_update_all( + predicate="source.deleted = True", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False, False, False, False, True]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_multiple_when_not_matched_insert_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "9"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + }, + predicate="source.price < bigint'50'", + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + }, + predicate="source.price > bigint'50'", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6", "9"]), + "price": pa.array([0, 1, 2, 3, 4, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 10, 20], pa.int32()), + "deleted": pa.array([False] * 7), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_multiple_when_matched_delete_with_predicate( + tmp_path: 
pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["5", "4"]), + "weight": pa.array([1, 2], pa.int64()), + "sold": pa.array([1, 2], pa.int32()), + "deleted": pa.array([True, False]), + "customer": pa.array(["Adam", "Patrick"]), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_delete("s.deleted = True").when_matched_delete( + "s.deleted = false" + ).execute() + + nrows = 3 + expected = pa.table( + { + "id": pa.array(["1", "2", "3"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_multiple_when_not_matched_by_source_update_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + """The first match clause that meets the predicate will be executed, so if you do an update + without another predicate the first clause will be matched and therefore the other ones are skipped.
+ """ + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_update( + updates={ + "sold": "int'10'", + } + ).when_not_matched_by_source_update( + updates={ + "sold": "int'100'", + } + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 3, 4], pa.int64()), + "sold": pa.array([10, 10, 10, 10, 10], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected