From b4fd5cf33183ad620db2e9afe9a20c1b89bc74e1 Mon Sep 17 00:00:00 2001 From: helanto Date: Wed, 11 Sep 2024 19:59:30 +0300 Subject: [PATCH] chore: Create pyo3 class for CommitProperties --- python/deltalake/_internal.pyi | 49 +++---- python/deltalake/table.py | 102 ++++++++------ python/deltalake/writer.py | 18 +-- python/src/lib.rs | 234 +++++++++++++-------------------- python/src/merge.rs | 15 +-- 5 files changed, 187 insertions(+), 231 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 6bf80a8a03..b7fc21f484 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -3,7 +3,12 @@ from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union import pyarrow import pyarrow.fs as fs -from deltalake.writer import AddAction, PostCommitHookProperties, WriterProperties +from deltalake.writer import ( + AddAction, + CommitProperties, + PostCommitHookProperties, + WriterProperties, +) __version__: str @@ -57,9 +62,8 @@ class RawDeltaTable: dry_run: bool, retention_hours: Optional[int], enforce_retention_duration: bool, - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> List[str]: ... def compact_optimize( self, @@ -68,9 +72,8 @@ class RawDeltaTable: max_concurrent_tasks: Optional[int], min_commit_interval: Optional[int], writer_properties: Optional[WriterProperties], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> str: ... def z_order_optimize( self, @@ -81,46 +84,40 @@ class RawDeltaTable: max_spill_size: Optional[int], min_commit_interval: Optional[int], writer_properties: Optional[WriterProperties], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> str: ... def add_columns( self, fields: List[Field], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> None: ... def add_constraints( self, constraints: Dict[str, str], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> None: ... def drop_constraints( self, name: str, raise_if_not_exists: bool, - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> None: ... def set_table_properties( self, properties: Dict[str, str], raise_if_not_exists: bool, - custom_metadata: Optional[Dict[str, str]], - max_commit_retries: Optional[int], + commit_properties: Optional[CommitProperties], ) -> None: ... def restore( self, target: Optional[Any], ignore_missing_files: bool, protocol_downgrade_allowed: bool, - custom_metadata: Optional[Dict[str, str]], - max_commit_retries: Optional[int], + commit_properties: Optional[CommitProperties], ) -> str: ... def history(self, limit: Optional[int]) -> List[str]: ... def update_incremental(self) -> None: ... @@ -133,16 +130,14 @@ class RawDeltaTable: self, predicate: Optional[str], writer_properties: Optional[WriterProperties], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> str: ... def repair( self, dry_run: bool, - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> str: ... def update( self, @@ -150,9 +145,8 @@ class RawDeltaTable: predicate: Optional[str], writer_properties: Optional[WriterProperties], safe_cast: bool, - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> str: ... def create_merge_builder( self, @@ -161,10 +155,9 @@ class RawDeltaTable: source_alias: Optional[str], target_alias: Optional[str], writer_properties: Optional[WriterProperties], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], safe_cast: bool, - max_commit_retries: Optional[int], ) -> PyMergeBuilder: ... def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ... def get_active_partitions( @@ -177,9 +170,8 @@ class RawDeltaTable: partition_by: List[str], schema: pyarrow.Schema, partitions_filters: Optional[FilterType], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> None: ... def cleanup_metadata(self) -> None: ... def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ... @@ -219,9 +211,8 @@ def write_to_deltalake( configuration: Optional[Mapping[str, Optional[str]]], storage_options: Optional[Dict[str, str]], writer_properties: Optional[WriterProperties], - custom_metadata: Optional[Dict[str, str]], + commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], - max_commit_retries: Optional[int], ) -> None: ... def convert_to_deltalake( uri: str, diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 72ffd9e3ec..11a8baa0dd 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -167,6 +167,17 @@ def __init__( self.max_commit_retries = max_commit_retries +def _commit_properties_from_custom_metadata( + maybe_properties: Optional[CommitProperties], custom_metadata: Dict[str, str] +) -> CommitProperties: + if maybe_properties is not None: + if maybe_properties.custom_metadata is None: + maybe_properties.custom_metadata = custom_metadata + return maybe_properties + return maybe_properties + return CommitProperties(custom_metadata=custom_metadata) + + @dataclass(init=True) class BloomFilterProperties: """The Bloom Filter Properties instance for the Rust parquet writer.""" @@ -784,6 +795,9 @@ def vacuum( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if retention_hours: if retention_hours < 0: @@ -793,9 +807,8 @@ def vacuum( dry_run, retention_hours, enforce_retention_duration, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) def update( @@ -870,6 +883,9 @@ def update( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if updates is None and new_values is not None: updates = {} @@ -907,13 +923,8 @@ def update( predicate, writer_properties, safe_cast=not error_on_type_mismatch, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, + commit_properties=commit_properties, post_commithook_properties=post_commithook_properties, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, ) return json.loads(metrics) @@ -984,6 +995,9 @@ def merge( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if large_dtypes: warnings.warn( @@ -1028,13 +1042,8 @@ def merge( target_alias=target_alias, safe_cast=not error_on_type_mismatch, writer_properties=writer_properties, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, + commit_properties=commit_properties, post_commithook_properties=post_commithook_properties, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, ) return TableMerger(py_merge_builder, self._table) @@ -1066,30 +1075,23 @@ def restore( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if isinstance(target, datetime): metrics = self._table.restore( target.isoformat(), ignore_missing_files=ignore_missing_files, protocol_downgrade_allowed=protocol_downgrade_allowed, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, + commit_properties=commit_properties, ) else: metrics = self._table.restore( target, ignore_missing_files=ignore_missing_files, protocol_downgrade_allowed=protocol_downgrade_allowed, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, + commit_properties=commit_properties, ) return json.loads(metrics) @@ -1343,13 +1345,15 @@ def delete( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) metrics = self._table.delete( predicate, writer_properties, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) return json.loads(metrics) @@ -1393,12 +1397,14 @@ def repair( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) metrics = self._table.repair( dry_run, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) return json.loads(metrics) @@ -1820,15 +1826,17 @@ def add_columns( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if isinstance(fields, DeltaField): fields = [fields] self.table._table.add_columns( fields, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) def add_constraint( @@ -1868,6 +1876,9 @@ def add_constraint( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if len(constraints.keys()) > 1: raise ValueError( @@ -1877,9 +1888,8 @@ def add_constraint( self.table._table.add_constraints( constraints, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) def drop_constraint( @@ -1925,13 +1935,15 @@ def drop_constraint( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) self.table._table.drop_constraints( name, raise_if_not_exists, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) def set_table_properties( @@ -1971,12 +1983,14 @@ def set_table_properties( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) self.table._table.set_table_properties( properties, raise_if_not_exists, - commit_properties.custom_metadata if commit_properties else custom_metadata, - commit_properties.max_commit_retries if commit_properties else None, + commit_properties, ) @@ -2047,6 +2061,9 @@ def compact( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if isinstance(min_commit_interval, timedelta): min_commit_interval = int(min_commit_interval.total_seconds()) @@ -2057,9 +2074,8 @@ def compact( max_concurrent_tasks, min_commit_interval, writer_properties, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) self.table.update_incremental() return json.loads(metrics) @@ -2125,6 +2141,9 @@ def z_order( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) if isinstance(min_commit_interval, timedelta): min_commit_interval = int(min_commit_interval.total_seconds()) @@ -2137,9 +2156,8 @@ def z_order( max_spill_size, min_commit_interval, writer_properties, - commit_properties.custom_metadata if commit_properties else custom_metadata, + commit_properties, post_commithook_properties, - commit_properties.max_commit_retries if commit_properties else None, ) self.table.update_incremental() return json.loads(metrics) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index fdadb59570..535a6e7a13 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -56,6 +56,7 @@ DeltaTable, PostCommitHookProperties, WriterProperties, + _commit_properties_from_custom_metadata, ) try: @@ -290,6 +291,9 @@ def write_deltalake( category=DeprecationWarning, stacklevel=2, ) + commit_properties = _commit_properties_from_custom_metadata( + commit_properties, custom_metadata + ) table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) if table is not None: @@ -330,13 +334,8 @@ def write_deltalake( configuration=configuration, storage_options=storage_options, writer_properties=writer_properties, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, + commit_properties=commit_properties, post_commithook_properties=post_commithook_properties, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, ) if table: table.update_incremental() @@ -570,13 +569,8 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: partition_by or [], schema, partition_filters, - custom_metadata=commit_properties.custom_metadata - if commit_properties - else custom_metadata, + commit_properties=commit_properties, post_commithook_properties=post_commithook_properties, - max_commit_retries=commit_properties.max_commit_retries - if commit_properties - else None, ) table.update_incremental() else: diff --git a/python/src/lib.rs b/python/src/lib.rs index 3ed38dd8cd..fc1c18c880 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -321,16 +321,15 @@ impl RawDeltaTable { /// Run the Vacuum command on the Delta Table: list and delete files no longer referenced /// by the Delta table and are older than the retention threshold. - #[pyo3(signature = (dry_run, retention_hours = None, enforce_retention_duration = true, custom_metadata=None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (dry_run, retention_hours = None, enforce_retention_duration = true, commit_properties=None, post_commithook_properties=None))] pub fn vacuum( &mut self, py: Python, dry_run: bool, retention_hours: Option, enforce_retention_duration: bool, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult> { let (table, metrics) = py.allow_threads(|| { let mut cmd = VacuumBuilder::new( @@ -343,11 +342,9 @@ impl RawDeltaTable { cmd = cmd.with_retention_period(Duration::hours(retention_period as i64)); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } rt().block_on(cmd.into_future()).map_err(PythonError::from) @@ -357,7 +354,7 @@ impl RawDeltaTable { } /// Run the UPDATE command on the Delta Table - #[pyo3(signature = (updates, predicate=None, writer_properties=None, safe_cast = false, custom_metadata = None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (updates, predicate=None, writer_properties=None, safe_cast = false, commit_properties = None, post_commithook_properties=None))] #[allow(clippy::too_many_arguments)] pub fn update( &mut self, @@ -366,9 +363,8 @@ impl RawDeltaTable { predicate: Option, writer_properties: Option, safe_cast: bool, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = UpdateBuilder::new( @@ -391,11 +387,9 @@ impl RawDeltaTable { cmd = cmd.with_predicate(update_predicate); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -413,9 +407,8 @@ impl RawDeltaTable { max_concurrent_tasks = None, min_commit_interval = None, writer_properties=None, - custom_metadata=None, - post_commithook_properties=None, - max_commit_retries=None, + commit_properties=None, + post_commithook_properties=None ))] #[allow(clippy::too_many_arguments)] pub fn compact_optimize( @@ -426,9 +419,8 @@ impl RawDeltaTable { max_concurrent_tasks: Option, min_commit_interval: Option, writer_properties: Option, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = OptimizeBuilder::new( @@ -449,11 +441,9 @@ impl RawDeltaTable { ); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -477,9 +467,8 @@ impl RawDeltaTable { max_spill_size = 20 * 1024 * 1024 * 1024, min_commit_interval = None, writer_properties=None, - custom_metadata=None, - post_commithook_properties=None, - max_commit_retries=None))] + commit_properties=None, + post_commithook_properties=None))] pub fn z_order_optimize( &mut self, py: Python, @@ -490,9 +479,8 @@ impl RawDeltaTable { max_spill_size: usize, min_commit_interval: Option, writer_properties: Option, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = OptimizeBuilder::new( @@ -515,11 +503,9 @@ impl RawDeltaTable { ); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -534,14 +520,13 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } - #[pyo3(signature = (fields, custom_metadata=None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (fields, commit_properties=None, post_commithook_properties=None))] pub fn add_columns( &mut self, py: Python, fields: Vec, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = AddColumnBuilder::new( @@ -556,11 +541,9 @@ impl RawDeltaTable { cmd = cmd.with_fields(new_fields); - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -570,14 +553,13 @@ impl RawDeltaTable { Ok(()) } - #[pyo3(signature = (constraints, custom_metadata=None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (constraints, commit_properties=None, post_commithook_properties=None))] pub fn add_constraints( &mut self, py: Python, constraints: HashMap, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = ConstraintBuilder::new( @@ -589,11 +571,9 @@ impl RawDeltaTable { cmd = cmd.with_constraint(col_name.clone(), expression.clone()); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -603,15 +583,14 @@ impl RawDeltaTable { Ok(()) } - #[pyo3(signature = (name, raise_if_not_exists, custom_metadata=None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (name, raise_if_not_exists, commit_properties=None, post_commithook_properties=None))] pub fn drop_constraints( &mut self, py: Python, name: String, raise_if_not_exists: bool, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = DropConstraintBuilder::new( @@ -621,11 +600,9 @@ impl RawDeltaTable { .with_constraint(name) .with_raise_if_not_exists(raise_if_not_exists); - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -711,8 +688,7 @@ impl RawDeltaTable { safe_cast = false, writer_properties = None, post_commithook_properties = None, - custom_metadata = None, - max_commit_retries=None, + commit_properties = None, ))] pub fn create_merge_builder( &self, @@ -724,8 +700,7 @@ impl RawDeltaTable { safe_cast: bool, writer_properties: Option, post_commithook_properties: Option, - custom_metadata: Option>, - max_commit_retries: Option, + commit_properties: Option, ) -> PyResult { py.allow_threads(|| { Ok(PyMergeBuilder::new( @@ -738,8 +713,7 @@ impl RawDeltaTable { safe_cast, writer_properties, post_commithook_properties, - custom_metadata, - max_commit_retries, + commit_properties, ) .map_err(PythonError::from)?) }) @@ -761,14 +735,13 @@ impl RawDeltaTable { } // Run the restore command on the Delta Table: restore table to a given version or datetime - #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false, custom_metadata=None, max_commit_retries=None))] + #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false, commit_properties=None))] pub fn restore( &mut self, target: Option<&Bound<'_, PyAny>>, ignore_missing_files: bool, protocol_downgrade_allowed: bool, - custom_metadata: Option>, - max_commit_retries: Option, + commit_properties: Option, ) -> PyResult { let mut cmd = RestoreBuilder::new( self._table.log_store(), @@ -790,9 +763,7 @@ impl RawDeltaTable { cmd = cmd.with_ignore_missing_files(ignore_missing_files); cmd = cmd.with_protocol_downgrade_allowed(protocol_downgrade_allowed); - if let Some(commit_properties) = - maybe_create_commit_properties(custom_metadata, max_commit_retries, None) - { + if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, None) { cmd = cmd.with_commit_properties(commit_properties); } @@ -953,9 +924,8 @@ impl RawDeltaTable { partition_by: Vec, schema: PyArrowType, partitions_filters: Option>, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult<()> { py.allow_threads(|| { let mode = mode.parse().map_err(PythonError::from)?; @@ -1039,24 +1009,25 @@ impl RawDeltaTable { predicate: None, }; - let mut commit_properties = CommitProperties::default(); - if let Some(metadata) = custom_metadata { - let json_metadata: Map = - metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); - commit_properties = commit_properties.with_metadata(json_metadata); - }; - - if let Some(max_retries) = max_commit_retries { - commit_properties = commit_properties.with_max_retries(max_retries); - }; + let mut properties = CommitProperties::default(); + if let Some(props) = commit_properties { + if let Some(metadata) = props.custom_metadata { + let json_metadata: Map = + metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); + properties = properties.with_metadata(json_metadata); + }; + + if let Some(max_retries) = props.max_commit_retries { + properties = properties.with_max_retries(max_retries); + }; + } if let Some(post_commit_hook_props) = post_commithook_properties { - commit_properties = - set_post_commithook_properties(commit_properties, post_commit_hook_props) + properties = set_post_commithook_properties(properties, post_commit_hook_props) } rt().block_on( - CommitBuilder::from(commit_properties) + CommitBuilder::from(properties) .with_actions(actions) .build( Some(self._table.snapshot().map_err(PythonError::from)?), @@ -1122,15 +1093,14 @@ impl RawDeltaTable { .collect::>()) } /// Run the delete command on the delta table: delete records following a predicate and return the delete metrics. - #[pyo3(signature = (predicate = None, writer_properties=None, custom_metadata=None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (predicate = None, writer_properties=None, commit_properties=None, post_commithook_properties=None))] pub fn delete( &mut self, py: Python, predicate: Option, writer_properties: Option, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = DeleteBuilder::new( @@ -1145,11 +1115,9 @@ impl RawDeltaTable { set_writer_properties(writer_props).map_err(PythonError::from)?, ); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -1159,13 +1127,12 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } - #[pyo3(signature = (properties, raise_if_not_exists, custom_metadata=None, max_commit_retries=None))] + #[pyo3(signature = (properties, raise_if_not_exists, commit_properties=None))] pub fn set_table_properties( &mut self, properties: HashMap, raise_if_not_exists: bool, - custom_metadata: Option>, - max_commit_retries: Option, + commit_properties: Option, ) -> PyResult<()> { let mut cmd = SetTablePropertiesBuilder::new( self._table.log_store(), @@ -1174,9 +1141,7 @@ impl RawDeltaTable { .with_properties(properties) .with_raise_if_not_exists(raise_if_not_exists); - if let Some(commit_properties) = - maybe_create_commit_properties(custom_metadata, max_commit_retries, None) - { + if let Some(commit_properties) = maybe_create_commit_properties(commit_properties, None) { cmd = cmd.with_commit_properties(commit_properties); } @@ -1189,13 +1154,12 @@ impl RawDeltaTable { /// Execute the File System Check command (FSCK) on the delta table: removes old reference to files that /// have been deleted or are malformed - #[pyo3(signature = (dry_run = true, custom_metadata = None, post_commithook_properties=None, max_commit_retries=None))] + #[pyo3(signature = (dry_run = true, commit_properties = None, post_commithook_properties=None))] pub fn repair( &mut self, dry_run: bool, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult { let mut cmd = FileSystemCheckBuilder::new( self._table.log_store(), @@ -1203,11 +1167,9 @@ impl RawDeltaTable { ) .with_dry_run(dry_run); - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } @@ -1345,23 +1307,25 @@ fn convert_partition_filters( } fn maybe_create_commit_properties( - custom_metadata: Option>, - max_commit_retries: Option, + maybe_commit_properties: Option, post_commithook_properties: Option, ) -> Option { - if custom_metadata.is_none() && post_commithook_properties.is_none() { + if maybe_commit_properties.is_none() && post_commithook_properties.is_none() { return None; } let mut commit_properties = CommitProperties::default(); - if let Some(metadata) = custom_metadata { - let json_metadata: Map = - metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); - commit_properties = commit_properties.with_metadata(json_metadata); - }; - if let Some(max_retries) = max_commit_retries { - commit_properties = commit_properties.with_max_retries(max_retries); - }; + if let Some(commit_props) = maybe_commit_properties { + if let Some(metadata) = commit_props.custom_metadata { + let json_metadata: Map = + metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); + commit_properties = commit_properties.with_metadata(json_metadata); + }; + + if let Some(max_retries) = commit_props.max_commit_retries { + commit_properties = commit_properties.with_max_retries(max_retries); + }; + } if let Some(post_commit_hook_props) = post_commithook_properties { commit_properties = @@ -1639,6 +1603,12 @@ pub struct PyPostCommitHookProperties { cleanup_expired_logs: Option, } +#[derive(FromPyObject)] +pub struct PyCommitProperties { + custom_metadata: Option>, + max_commit_retries: Option, +} + #[pyfunction] #[allow(clippy::too_many_arguments)] fn write_to_deltalake( @@ -1656,9 +1626,8 @@ fn write_to_deltalake( configuration: Option>>, storage_options: Option>, writer_properties: Option, - custom_metadata: Option>, + commit_properties: Option, post_commithook_properties: Option, - max_commit_retries: Option, ) -> PyResult<()> { py.allow_threads(|| { let batches = data.0.map(|batch| batch.unwrap()).collect::>(); @@ -1708,24 +1677,11 @@ fn write_to_deltalake( builder = builder.with_configuration(config); }; - if custom_metadata.is_some() || post_commithook_properties.is_some() { - let mut commit_properties = CommitProperties::default(); - if let Some(metadata) = custom_metadata { - let json_metadata: Map = - metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); - commit_properties = commit_properties.with_metadata(json_metadata); - }; - - if let Some(max_retries) = max_commit_retries { - commit_properties = commit_properties.with_max_retries(max_retries); - }; - - if let Some(post_commit_hook_props) = post_commithook_properties { - commit_properties = - set_post_commithook_properties(commit_properties, post_commit_hook_props) - } + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { builder = builder.with_commit_properties(commit_properties); - } + }; rt().block_on(builder.into_future()) .map_err(PythonError::from)?; diff --git a/python/src/merge.rs b/python/src/merge.rs index fd5e95081c..e1e427f46d 100644 --- a/python/src/merge.rs +++ b/python/src/merge.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use crate::error::PythonError; use crate::utils::rt; use crate::{ - maybe_create_commit_properties, set_writer_properties, PyPostCommitHookProperties, - PyWriterProperties, + maybe_create_commit_properties, set_writer_properties, PyCommitProperties, + PyPostCommitHookProperties, PyWriterProperties, }; #[pyclass(module = "deltalake._internal")] @@ -43,8 +43,7 @@ impl PyMergeBuilder { safe_cast: bool, writer_properties: Option, post_commithook_properties: Option, - custom_metadata: Option>, - max_commit_retries: Option, + commit_properties: Option, ) -> DeltaResult { let ctx = SessionContext::new(); let schema = source.schema(); @@ -68,11 +67,9 @@ impl PyMergeBuilder { cmd = cmd.with_writer_properties(set_writer_properties(writer_props)?); } - if let Some(commit_properties) = maybe_create_commit_properties( - custom_metadata, - max_commit_retries, - post_commithook_properties, - ) { + if let Some(commit_properties) = + maybe_create_commit_properties(commit_properties, post_commithook_properties) + { cmd = cmd.with_commit_properties(commit_properties); } Ok(Self {