From 157bd3d8ce53d78194883d2c96e2b2da24862a88 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Fri, 23 Aug 2024 18:07:55 +0200 Subject: [PATCH] refactor: python merge --- python/deltalake/_internal.pyi | 35 ++++-- python/deltalake/table.py | 166 +++++--------------------- python/src/lib.rs | 208 ++++++-------------------------- python/src/merge.rs | 210 +++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 320 deletions(-) create mode 100644 python/src/merge.rs diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 440f7a9eec..1e50d9bdf8 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -143,7 +143,7 @@ class RawDeltaTable: custom_metadata: Optional[Dict[str, str]], post_commithook_properties: Optional[Dict[str, Optional[bool]]], ) -> str: ... - def merge_execute( + def create_merge_builder( self, source: pyarrow.RecordBatchReader, predicate: str, @@ -153,17 +153,8 @@ class RawDeltaTable: custom_metadata: Optional[Dict[str, str]], post_commithook_properties: Optional[Dict[str, Optional[bool]]], safe_cast: bool, - matched_update_updates: Optional[List[Dict[str, str]]], - matched_update_predicate: Optional[List[Optional[str]]], - matched_delete_predicate: Optional[List[str]], - matched_delete_all: Optional[bool], - not_matched_insert_updates: Optional[List[Dict[str, str]]], - not_matched_insert_predicate: Optional[List[Optional[str]]], - not_matched_by_source_update_updates: Optional[List[Dict[str, str]]], - not_matched_by_source_update_predicate: Optional[List[Optional[str]]], - not_matched_by_source_delete_predicate: Optional[List[str]], - not_matched_by_source_delete_all: Optional[bool], - ) -> str: ... + ) -> PyMergeBuilder: ... + def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ... def get_active_partitions( self, partitions_filters: Optional[FilterType] = None ) -> Any: ... @@ -244,6 +235,26 @@ def get_num_idx_cols_and_stats_columns( table: Optional[RawDeltaTable], configuration: Optional[Mapping[str, Optional[str]]] ) -> Tuple[int, Optional[List[str]]]: ... +class PyMergeBuilder: + source_alias: str + target_alias: str + arrow_schema: pyarrow.Schema + + def when_matched_update( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_matched_delete(self, predicate: Optional[str]) -> None: ... + def when_not_matched_insert( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_not_matched_by_source_update( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_not_matched_by_source_delete( + self, + predicate: Optional[str], + ) -> None: ... + # Can't implement inheritance (see note in src/schema.rs), so this is next # best thing. DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"] diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 49a06e11bd..bccd3a2d7e 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -39,6 +39,7 @@ import os from deltalake._internal import ( + PyMergeBuilder, RawDeltaTable, ) from deltalake._internal import create_deltalake as _create_deltalake @@ -952,8 +953,7 @@ def merge( source.schema, (batch for batch in source) ) - return TableMerger( - self, + py_merge_builder = self._table.create_merge_builder( source=source, predicate=predicate, source_alias=source_alias, @@ -961,8 +961,11 @@ def merge( safe_cast=not error_on_type_mismatch, writer_properties=writer_properties, custom_metadata=custom_metadata, - post_commithook_properties=post_commithook_properties, + post_commithook_properties=post_commithook_properties.__dict__ + if post_commithook_properties + else None, ) + return TableMerger(py_merge_builder, self._table) def restore( self, @@ -1295,37 +1298,11 @@ class TableMerger: def __init__( self, - table: DeltaTable, - source: pyarrow.RecordBatchReader, - predicate: str, - source_alias: Optional[str] = None, - target_alias: Optional[str] = None, - safe_cast: bool = True, - writer_properties: Optional[WriterProperties] = None, - custom_metadata: Optional[Dict[str, str]] = None, - post_commithook_properties: Optional[PostCommitHookProperties] = None, + builder: PyMergeBuilder, + table: RawDeltaTable, ): - self.table = table - self.source = source - self.predicate = predicate - self.source_alias = source_alias - self.target_alias = target_alias - self.safe_cast = safe_cast - self.writer_properties = writer_properties - self.custom_metadata = custom_metadata - self.post_commithook_properties = post_commithook_properties - self.matched_update_updates: Optional[List[Dict[str, str]]] = None - self.matched_update_predicate: Optional[List[Optional[str]]] = None - self.matched_delete_predicate: Optional[List[str]] = None - self.matched_delete_all: Optional[bool] = None - self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None - self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = ( - None - ) - self.not_matched_by_source_delete_predicate: Optional[List[str]] = None - self.not_matched_by_source_delete_all: Optional[bool] = None + self._builder = builder + self._table = table def when_matched_update( self, updates: Dict[str, str], predicate: Optional[str] = None @@ -1372,14 +1349,7 @@ def when_matched_update( 2 3 6 ``` """ - if isinstance(self.matched_update_updates, list) and isinstance( - self.matched_update_predicate, list - ): - self.matched_update_updates.append(updates) - self.matched_update_predicate.append(predicate) - else: - self.matched_update_updates = [updates] - self.matched_update_predicate = [predicate] + self._builder.when_matched_update(updates, predicate) return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1424,24 +1394,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg 2 3 6 ``` """ + maybe_source_alias = self._builder.source_alias + maybe_target_alias = self._builder.target_alias - src_alias = (self.source_alias + ".") if self.source_alias is not None else "" - trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else "" + trgt_alias = ( + (maybe_target_alias + ".") if maybe_target_alias is not None else "" + ) updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" - for col in self.source.schema + for col in self._builder.arrow_schema } - if isinstance(self.matched_update_updates, list) and isinstance( - self.matched_update_predicate, list - ): - self.matched_update_updates.append(updates) - self.matched_update_predicate.append(predicate) - else: - self.matched_update_updates = [updates] - self.matched_update_predicate = [predicate] - + self._builder.when_matched_update(updates, predicate) return self def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1507,19 +1473,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": 0 1 4 ``` """ - if self.matched_delete_all is not None: - raise ValueError( - """when_matched_delete without a predicate has already been set, which means - it will delete all, any subsequent when_matched_delete, won't make sense.""" - ) - - if predicate is None: - self.matched_delete_all = True - else: - if isinstance(self.matched_delete_predicate, list): - self.matched_delete_predicate.append(predicate) - else: - self.matched_delete_predicate = [predicate] + self._builder.when_matched_delete(predicate) return self def when_not_matched_insert( @@ -1572,16 +1526,7 @@ def when_not_matched_insert( 3 4 7 ``` """ - - if isinstance(self.not_matched_insert_updates, list) and isinstance( - self.not_matched_insert_predicate, list - ): - self.not_matched_insert_updates.append(updates) - self.not_matched_insert_predicate.append(predicate) - else: - self.not_matched_insert_updates = [updates] - self.not_matched_insert_predicate = [predicate] - + self._builder.when_not_matched_insert(updates, predicate) return self def when_not_matched_insert_all( @@ -1630,22 +1575,19 @@ def when_not_matched_insert_all( 3 4 7 ``` """ + maybe_source_alias = self._builder.source_alias + maybe_target_alias = self._builder.target_alias - src_alias = (self.source_alias + ".") if self.source_alias is not None else "" - trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else "" + trgt_alias = ( + (maybe_target_alias + ".") if maybe_target_alias is not None else "" + ) updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" - for col in self.source.schema + for col in self._builder.arrow_schema } - if isinstance(self.not_matched_insert_updates, list) and isinstance( - self.not_matched_insert_predicate, list - ): - self.not_matched_insert_updates.append(updates) - self.not_matched_insert_predicate.append(predicate) - else: - self.not_matched_insert_updates = [updates] - self.not_matched_insert_predicate = [predicate] + self._builder.when_not_matched_insert(updates, predicate) return self def when_not_matched_by_source_update( @@ -1695,15 +1637,7 @@ def when_not_matched_by_source_update( 2 3 6 ``` """ - - if isinstance(self.not_matched_by_source_update_updates, list) and isinstance( - self.not_matched_by_source_update_predicate, list - ): - self.not_matched_by_source_update_updates.append(updates) - self.not_matched_by_source_update_predicate.append(predicate) - else: - self.not_matched_by_source_update_updates = [updates] - self.not_matched_by_source_update_predicate = [predicate] + self._builder.when_not_matched_by_source_update(updates, predicate) return self def when_not_matched_by_source_delete( @@ -1722,19 +1656,7 @@ def when_not_matched_by_source_delete( Returns: TableMerger: TableMerger Object """ - if self.not_matched_by_source_delete_all is not None: - raise ValueError( - """when_not_matched_by_source_delete without a predicate has already been set, which means - it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense.""" - ) - - if predicate is None: - self.not_matched_by_source_delete_all = True - else: - if isinstance(self.not_matched_by_source_delete_predicate, list): - self.not_matched_by_source_delete_predicate.append(predicate) - else: - self.not_matched_by_source_delete_predicate = [predicate] + self._builder.when_not_matched_by_source_delete(predicate) return self def execute(self) -> Dict[str, Any]: @@ -1743,31 +1665,7 @@ def execute(self) -> Dict[str, Any]: Returns: Dict: metrics """ - metrics = self.table._table.merge_execute( - source=self.source, - predicate=self.predicate, - source_alias=self.source_alias, - target_alias=self.target_alias, - safe_cast=self.safe_cast, - writer_properties=self.writer_properties - if self.writer_properties - else None, - custom_metadata=self.custom_metadata, - post_commithook_properties=self.post_commithook_properties.__dict__ - if self.post_commithook_properties - else None, - matched_update_updates=self.matched_update_updates, - matched_update_predicate=self.matched_update_predicate, - matched_delete_predicate=self.matched_delete_predicate, - matched_delete_all=self.matched_delete_all, - not_matched_insert_updates=self.not_matched_insert_updates, - not_matched_insert_predicate=self.not_matched_insert_predicate, - not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, - not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, - not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, - not_matched_by_source_delete_all=self.not_matched_by_source_delete_all, - ) - self.table.update_incremental() + metrics = self._table.merge_execute(self._builder) return json.loads(metrics) diff --git a/python/src/lib.rs b/python/src/lib.rs index 63d0bcc17d..45c3d20bd2 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,12 +1,12 @@ mod error; mod filesystem; +mod merge; mod schema; mod utils; use std::collections::{HashMap, HashSet}; use std::future::IntoFuture; use std::str::FromStr; -use std::sync::Arc; use std::time; use std::time::{SystemTime, UNIX_EPOCH}; @@ -16,12 +16,9 @@ use delta_kernel::expressions::Scalar; use delta_kernel::schema::StructField; use deltalake::arrow::compute::concat_batches; use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; -use deltalake::arrow::record_batch::RecordBatchReader; use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator}; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::{cleanup_metadata, create_checkpoint}; -use deltalake::datafusion::catalog::TableProvider; -use deltalake::datafusion::datasource::memory::MemTable; use deltalake::datafusion::physical_plan::ExecutionPlan; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; @@ -37,7 +34,6 @@ use deltalake::operations::delete::DeleteBuilder; use deltalake::operations::drop_constraints::DropConstraintBuilder; use deltalake::operations::filesystem_check::FileSystemCheckBuilder; use deltalake::operations::load_cdf::CdfLoadBuilder; -use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::set_tbl_properties::SetTablePropertiesBuilder; @@ -55,6 +51,7 @@ use deltalake::storage::IORuntime; use deltalake::DeltaTableBuilder; use deltalake::{DeltaOps, DeltaResult}; use futures::future::join_all; + use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -64,6 +61,7 @@ use serde_json::{Map, Value}; use crate::error::DeltaProtocolError; use crate::error::PythonError; use crate::filesystem::FsConfig; +use crate::merge::PyMergeBuilder; use crate::schema::{schema_to_pyobject, Field}; use crate::utils::rt; @@ -682,7 +680,8 @@ impl RawDeltaTable { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (source, + #[pyo3(signature = ( + source, predicate, source_alias = None, target_alias = None, @@ -690,19 +689,9 @@ impl RawDeltaTable { writer_properties = None, post_commithook_properties = None, custom_metadata = None, - matched_update_updates = None, - matched_update_predicate = None, - matched_delete_predicate = None, - matched_delete_all = None, - not_matched_insert_updates = None, - not_matched_insert_predicate = None, - not_matched_by_source_update_updates = None, - not_matched_by_source_update_predicate = None, - not_matched_by_source_delete_predicate = None, - not_matched_by_source_delete_all = None, ))] - pub fn merge_execute( - &mut self, + pub fn create_merge_builder( + &self, py: Python, source: PyArrowType, predicate: String, @@ -712,167 +701,37 @@ impl RawDeltaTable { writer_properties: Option, post_commithook_properties: Option>>, custom_metadata: Option>, - matched_update_updates: Option>>, - matched_update_predicate: Option>>, - matched_delete_predicate: Option>, - matched_delete_all: Option, - not_matched_insert_updates: Option>>, - not_matched_insert_predicate: Option>>, - not_matched_by_source_update_updates: Option>>, - not_matched_by_source_update_predicate: Option>>, - not_matched_by_source_delete_predicate: Option>, - not_matched_by_source_delete_all: Option, - ) -> PyResult { - let (table, metrics) = py.allow_threads(|| { - let ctx = SessionContext::new(); - let schema = source.0.schema(); - let batches = vec![source.0.map(|batch| batch.unwrap()).collect::>()]; - let table_provider: Arc = - Arc::new(MemTable::try_new(schema, batches).unwrap()); - let source_df = ctx.read_table(table_provider).unwrap(); - - let mut cmd = MergeBuilder::new( + ) -> PyResult { + py.allow_threads(|| { + Ok(PyMergeBuilder::new( self._table.log_store(), self._table.snapshot().map_err(PythonError::from)?.clone(), + source.0, predicate, - source_df, + source_alias, + target_alias, + safe_cast, + writer_properties, + post_commithook_properties, + custom_metadata, ) - .with_safe_cast(safe_cast); - - if let Some(src_alias) = source_alias { - cmd = cmd.with_source_alias(src_alias); - } - - if let Some(trgt_alias) = target_alias { - cmd = cmd.with_target_alias(trgt_alias); - } - - if let Some(writer_props) = writer_properties { - cmd = cmd.with_writer_properties( - set_writer_properties(writer_props).map_err(PythonError::from)?, - ); - } - - if let Some(commit_properties) = - maybe_create_commit_properties(custom_metadata, post_commithook_properties) - { - cmd = cmd.with_commit_properties(commit_properties); - } - - if let Some(mu_updates) = matched_update_updates { - if let Some(mu_predicate) = matched_update_predicate { - for it in mu_updates.iter().zip(mu_predicate.iter()) { - let (update_values, predicate_value) = it; - - if let Some(pred) = predicate_value { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(_md_delete_all) = matched_delete_all { - cmd = cmd - .when_matched_delete(|delete| delete) - .map_err(PythonError::from)?; - } else if let Some(md_predicate) = matched_delete_predicate { - for pred in md_predicate.iter() { - cmd = cmd - .when_matched_delete(|delete| delete.predicate(pred.clone())) - .map_err(PythonError::from)?; - } - } - - if let Some(nmi_updates) = not_matched_insert_updates { - if let Some(nmi_predicate) = not_matched_insert_predicate { - for it in nmi_updates.iter().zip(nmi_predicate.iter()) { - let (update_values, predicate_value) = it; - if let Some(pred) = predicate_value { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in update_values { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in update_values { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(nmbsu_updates) = not_matched_by_source_update_updates { - if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { - for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) { - let (update_values, predicate_value) = it; - if let Some(pred) = predicate_value { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(_nmbs_delete_all) = not_matched_by_source_delete_all { - cmd = cmd - .when_not_matched_by_source_delete(|delete| delete) - .map_err(PythonError::from)?; - } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { - for pred in nmbs_predicate.iter() { - cmd = cmd - .when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone())) - .map_err(PythonError::from)?; - } - } + .map_err(PythonError::from)?) + }) + } - rt().block_on(cmd.into_future()).map_err(PythonError::from) - })?; - self._table.state = table.state; - Ok(serde_json::to_string(&metrics).unwrap()) + #[pyo3(signature=( + merge_builder + ))] + pub fn merge_execute( + &mut self, + py: Python, + merge_builder: &mut PyMergeBuilder, + ) -> PyResult { + py.allow_threads(|| { + let (table, metrics) = merge_builder.execute().map_err(PythonError::from)?; + self._table.state = table.state; + Ok(metrics) + }) } // Run the restore command on the Delta Table: restore table to a given version or datetime @@ -2070,6 +1929,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m )?)?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; // There are issues with submodules, so we will expose them flat for now diff --git a/python/src/merge.rs b/python/src/merge.rs new file mode 100644 index 0000000000..e1bb1bf3a6 --- /dev/null +++ b/python/src/merge.rs @@ -0,0 +1,210 @@ +use deltalake::arrow::array::RecordBatchReader; +use deltalake::arrow::datatypes::Schema as ArrowSchema; +use deltalake::arrow::ffi_stream::ArrowArrayStreamReader; +use deltalake::arrow::pyarrow::IntoPyArrow; +use deltalake::datafusion::catalog::TableProvider; +use deltalake::datafusion::datasource::MemTable; +use deltalake::datafusion::prelude::SessionContext; +use deltalake::logstore::LogStoreRef; +use deltalake::operations::merge::MergeBuilder; +use deltalake::table::state::DeltaTableState; +use deltalake::{DeltaResult, DeltaTable}; +use pyo3::prelude::*; +use std::collections::HashMap; +use std::future::IntoFuture; +use std::sync::Arc; + +use crate::error::PythonError; +use crate::utils::rt; +use crate::{maybe_create_commit_properties, set_writer_properties, PyWriterProperties}; + +#[pyclass(module = "deltalake._internal")] +pub(crate) struct PyMergeBuilder { + _builder: Option, + #[pyo3(get)] + source_alias: Option, + #[pyo3(get)] + target_alias: Option, + arrow_schema: Arc, +} + +impl PyMergeBuilder { + pub fn new( + log_store: LogStoreRef, + snapshot: DeltaTableState, + source: ArrowArrayStreamReader, + predicate: String, + source_alias: Option, + target_alias: Option, + safe_cast: bool, + writer_properties: Option, + post_commithook_properties: Option>>, + custom_metadata: Option>, + ) -> DeltaResult { + let ctx = SessionContext::new(); + let schema = source.schema(); + let batches = vec![source.map(|batch| batch.unwrap()).collect::>()]; + let table_provider: Arc = + Arc::new(MemTable::try_new(schema.clone(), batches).unwrap()); + let source_df = ctx.read_table(table_provider).unwrap(); + + let mut cmd = + MergeBuilder::new(log_store, snapshot, predicate, source_df).with_safe_cast(safe_cast); + + if let Some(src_alias) = &source_alias { + cmd = cmd.with_source_alias(src_alias); + } + + if let Some(trgt_alias) = &target_alias { + cmd = cmd.with_target_alias(trgt_alias); + } + + if let Some(writer_props) = writer_properties { + cmd = cmd.with_writer_properties(set_writer_properties(writer_props)?); + } + + if let Some(commit_properties) = + maybe_create_commit_properties(custom_metadata, post_commithook_properties) + { + cmd = cmd.with_commit_properties(commit_properties); + } + Ok(Self { + _builder: Some(cmd), + source_alias: source_alias, + target_alias: target_alias, + arrow_schema: schema, + }) + } + + pub fn execute(&mut self) -> DeltaResult<(DeltaTable, String)> { + let (table, metrics) = rt().block_on(self._builder.take().unwrap().into_future())?; + Ok((table, serde_json::to_string(&metrics).unwrap())) + } +} + +#[pymethods] +impl PyMergeBuilder { + #[getter] + fn get_arrow_schema(&self, py: Python) -> PyResult { + Ok(::clone(&self.arrow_schema).into_pyarrow(py)?) + } + + #[pyo3(signature=( + updates, + predicate = None, + ))] + fn when_matched_update( + &mut self, + updates: HashMap, + predicate: Option, + ) -> PyResult<()> { + self._builder = match self._builder.take() { + Some(cmd) => Some( + cmd.when_matched_update(|mut update| { + for (column, expression) in updates { + update = update.update(column, expression) + } + if let Some(predicate) = predicate { + update = update.predicate(predicate) + }; + update + }) + .map_err(PythonError::from)?, + ), + None => unreachable!(), + }; + Ok(()) + } + + #[pyo3(signature=( + predicate = None, + ))] + fn when_matched_delete(&mut self, predicate: Option) -> PyResult<()> { + self._builder = match self._builder.take() { + Some(cmd) => Some( + cmd.when_matched_delete(|mut delete| { + if let Some(predicate) = predicate { + delete = delete.predicate(predicate) + }; + delete + }) + .map_err(PythonError::from)?, + ), + None => unreachable!(), + }; + Ok(()) + } + + #[pyo3(signature=( + updates, + predicate = None, + ))] + fn when_not_matched_insert( + &mut self, + updates: HashMap, + predicate: Option, + ) -> PyResult<()> { + self._builder = match self._builder.take() { + Some(cmd) => Some( + cmd.when_not_matched_insert(|mut insert| { + for (column, expression) in updates { + insert = insert.set(column, expression) + } + if let Some(predicate) = predicate { + insert = insert.predicate(predicate) + }; + insert + }) + .map_err(PythonError::from)?, + ), + None => unreachable!(), + }; + Ok(()) + } + + #[pyo3(signature=( + updates, + predicate = None, + ))] + fn when_not_matched_by_source_update( + &mut self, + updates: HashMap, + predicate: Option, + ) -> PyResult<()> { + self._builder = match self._builder.take() { + Some(cmd) => Some( + cmd.when_not_matched_by_source_update(|mut update| { + for (column, expression) in updates { + update = update.update(column, expression) + } + if let Some(predicate) = predicate { + update = update.predicate(predicate) + }; + update + }) + .map_err(PythonError::from)?, + ), + None => unreachable!(), + }; + Ok(()) + } + + #[pyo3(signature=( + predicate = None, + ))] + fn when_not_matched_by_source_delete(&mut self, predicate: Option) -> PyResult<()> { + self._builder = match self._builder.take() { + Some(cmd) => Some( + cmd.when_not_matched_by_source_delete(|mut delete| { + if let Some(predicate) = predicate { + delete = delete.predicate(predicate) + }; + delete + }) + .map_err(PythonError::from)?, + ), + None => unreachable!(), + }; + Ok(()) + } +}