From 457bdeec4bb408a1c8af709a8178fcd21c6faa9f Mon Sep 17 00:00:00 2001 From: Peter Ke Date: Mon, 14 Oct 2024 14:49:26 -0700 Subject: [PATCH 1/2] expose transactions in python --- python/deltalake/__init__.py | 1 + python/deltalake/_internal.pyi | 1 + python/deltalake/table.py | 15 ++++++++++++++ python/src/lib.rs | 38 +++++++++++++++++++++++++++++++++- python/tests/test_writer.py | 31 ++++++++++++++++++++++++++- 5 files changed, 84 insertions(+), 2 deletions(-) diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index 607a5d988b..43997076b2 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -15,6 +15,7 @@ from .table import DeltaTable as DeltaTable from .table import Metadata as Metadata from .table import PostCommitHookProperties as PostCommitHookProperties +from .table import Transaction as Transaction from .table import ( WriterProperties as WriterProperties, ) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 8329dddad9..41e0bb5196 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -221,6 +221,7 @@ class RawDeltaTable: starting_timestamp: Optional[str] = None, ending_timestamp: Optional[str] = None, ) -> pyarrow.RecordBatchReader: ... + def transaction_versions(self) -> Dict[str, str]: ... def rust_core_version() -> str: ... def write_new_deltalake( diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 9150be697c..f3c9e3bf6f 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -150,6 +150,13 @@ def __init__( self.cleanup_expired_logs = cleanup_expired_logs +@dataclass +class Transaction: + app_id: str + version: int + last_updated: Optional[int] = None + + @dataclass(init=True) class CommitProperties: """The commit properties. Controls the behaviour of the commit.""" @@ -158,6 +165,7 @@ def __init__( self, custom_metadata: Optional[Dict[str, str]] = None, max_commit_retries: Optional[int] = None, + app_transactions: Optional[List[Transaction]] = None, ): """Custom metadata to be stored in the commit. Controls the number of retries for the commit. @@ -167,6 +175,7 @@ def __init__( """ self.custom_metadata = custom_metadata self.max_commit_retries = max_commit_retries + self.app_transactions = app_transactions def _commit_properties_from_custom_metadata( @@ -1417,6 +1426,12 @@ def repair( ) return json.loads(metrics) + def transaction_versions(self) -> Dict[str, Dict[str, Any]]: + return { + app_id: json.loads(transaction) + for app_id, transaction in self._table.transaction_versions().items() + } + class TableMerger: """API for various table `MERGE` commands.""" diff --git a/python/src/lib.rs b/python/src/lib.rs index 473f5ceea9..d48fab5a5b 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -25,7 +25,7 @@ use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; use deltalake::kernel::{ - scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, + scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction, }; use deltalake::operations::add_column::AddColumnBuilder; use deltalake::operations::add_feature::AddTableFeatureBuilder; @@ -1232,6 +1232,19 @@ impl RawDeltaTable { self._table.state = table.state; Ok(serde_json::to_string(&metrics).unwrap()) } + + pub fn transaction_versions(&self) -> HashMap { + self._table + .get_app_transaction_version() + .iter() + .map(|(app_id, transaction)| { + ( + app_id.to_owned(), + serde_json::to_string(transaction).unwrap(), + ) + }) + .collect() + } } fn set_post_commithook_properties( @@ -1378,6 +1391,11 @@ fn maybe_create_commit_properties( if let Some(max_retries) = commit_props.max_commit_retries { commit_properties = commit_properties.with_max_retries(max_retries); }; + + if let Some(app_transactions) = commit_props.app_transactions { + let app_transactions = app_transactions.iter().map(Transaction::from).collect(); + commit_properties = commit_properties.with_application_transactions(app_transactions); + } } if let Some(post_commit_hook_props) = post_commithook_properties { @@ -1656,10 +1674,28 @@ pub struct PyPostCommitHookProperties { cleanup_expired_logs: Option, } +#[derive(FromPyObject)] +pub struct PyTransaction { + app_id: String, + version: i64, + last_updated: Option, +} + +impl From<&PyTransaction> for Transaction { + fn from(value: &PyTransaction) -> Self { + Transaction { + app_id: value.app_id.clone(), + version: value.version, + last_updated: value.last_updated, + } + } +} + #[derive(FromPyObject)] pub struct PyCommitProperties { custom_metadata: Option>, max_commit_retries: Option, + app_transactions: Option>, } #[pyfunction] diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index bcbd93ecc2..2ee1770e62 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -23,7 +23,7 @@ DeltaProtocolError, SchemaMismatchError, ) -from deltalake.table import ProtocolVersions +from deltalake.table import CommitProperties, ProtocolVersions, Transaction from deltalake.writer import try_get_table_and_table_uri try: @@ -1993,3 +1993,32 @@ def test_write_timestamp(tmp_path: pathlib.Path): # Now that a datetime has been passed through the writer version needs to # be upgraded to 7 to support timestampNtz assert protocol.min_writer_version == 2 + + +def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table): + transactions = [ + Transaction(app_id="app_1", version=1), + Transaction(app_id="app_2", version=2, last_updated=123456), + ] + commit_properties = CommitProperties(app_transactions=transactions) + write_deltalake( + table_or_uri=tmp_path, + data=sample_data, + mode="overwrite", + schema_mode="overwrite", + commit_properties=commit_properties, + ) + + delta_table = DeltaTable(tmp_path) + transactions = delta_table.transaction_versions() + + assert len(transactions) == 2 + assert transactions["app_1"] == { + "appId": "app_1", + "version": 1, + } + assert transactions["app_2"] == { + "appId": "app_2", + "version": 2, + "lastUpdated": 123456, + } From eb57bf67d4f518432bd5ae12c085dfe928e88da0 Mon Sep 17 00:00:00 2001 From: Peter Ke Date: Tue, 15 Oct 2024 11:14:39 -0700 Subject: [PATCH 2/2] update to return pytransaction --- python/deltalake/_internal.pyi | 11 ++++++- python/deltalake/table.py | 15 ++------- python/src/lib.rs | 57 +++++++++++++++++++++++++++------- python/tests/test_writer.py | 23 +++++++------- 4 files changed, 70 insertions(+), 36 deletions(-) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 41e0bb5196..66b5dc8f8f 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -221,7 +221,7 @@ class RawDeltaTable: starting_timestamp: Optional[str] = None, ending_timestamp: Optional[str] = None, ) -> pyarrow.RecordBatchReader: ... - def transaction_versions(self) -> Dict[str, str]: ... + def transaction_versions(self) -> Dict[str, Transaction]: ... def rust_core_version() -> str: ... def write_new_deltalake( @@ -907,3 +907,12 @@ FilterConjunctionType = List[FilterLiteralType] FilterDNFType = List[FilterConjunctionType] FilterType = Union[FilterConjunctionType, FilterDNFType] PartitionFilterType = List[Tuple[str, str, Union[str, List[str]]]] + +class Transaction: + app_id: str + version: int + last_updated: Optional[int] + + def __init__( + self, app_id: str, version: int, last_updated: Optional[int] = None + ) -> None: ... diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f3c9e3bf6f..e54a1c3f8c 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -43,6 +43,7 @@ PyMergeBuilder, RawDeltaTable, TableFeatures, + Transaction, ) from deltalake._internal import create_deltalake as _create_deltalake from deltalake._util import encode_partition_value @@ -150,13 +151,6 @@ def __init__( self.cleanup_expired_logs = cleanup_expired_logs -@dataclass -class Transaction: - app_id: str - version: int - last_updated: Optional[int] = None - - @dataclass(init=True) class CommitProperties: """The commit properties. Controls the behaviour of the commit.""" @@ -1426,11 +1420,8 @@ def repair( ) return json.loads(metrics) - def transaction_versions(self) -> Dict[str, Dict[str, Any]]: - return { - app_id: json.loads(transaction) - for app_id, transaction in self._table.transaction_versions().items() - } + def transaction_versions(self) -> Dict[str, Transaction]: + return self._table.transaction_versions() class TableMerger: diff --git a/python/src/lib.rs b/python/src/lib.rs index d48fab5a5b..005076c719 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1233,16 +1233,11 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } - pub fn transaction_versions(&self) -> HashMap { + pub fn transaction_versions(&self) -> HashMap { self._table .get_app_transaction_version() - .iter() - .map(|(app_id, transaction)| { - ( - app_id.to_owned(), - serde_json::to_string(transaction).unwrap(), - ) - }) + .into_iter() + .map(|(app_id, transaction)| (app_id, PyTransaction::from(transaction))) .collect() } } @@ -1674,11 +1669,48 @@ pub struct PyPostCommitHookProperties { cleanup_expired_logs: Option, } -#[derive(FromPyObject)] +#[derive(Clone)] +#[pyclass(name = "Transaction", module = "deltalake._internal")] pub struct PyTransaction { - app_id: String, - version: i64, - last_updated: Option, + #[pyo3(get)] + pub app_id: String, + #[pyo3(get)] + pub version: i64, + #[pyo3(get)] + pub last_updated: Option, +} + +#[pymethods] +impl PyTransaction { + #[new] + #[pyo3(signature = (app_id, version, last_updated = None))] + fn new(app_id: String, version: i64, last_updated: Option) -> Self { + Self { + app_id, + version, + last_updated, + } + } + + fn __repr__(&self) -> String { + format!( + "Transaction(app_id={}, version={}, last_updated={})", + self.app_id, + self.version, + self.last_updated + .map_or("None".to_owned(), |n| n.to_string()) + ) + } +} + +impl From for PyTransaction { + fn from(value: Transaction) -> Self { + PyTransaction { + app_id: value.app_id, + version: value.version, + last_updated: value.last_updated, + } + } } impl From<&PyTransaction> for Transaction { @@ -2039,6 +2071,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; // There are issues with submodules, so we will expose them flat for now // See also: https://github.com/PyO3/pyo3/issues/759 m.add_class::()?; diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 2ee1770e62..c43e5d1136 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1996,11 +1996,11 @@ def test_write_timestamp(tmp_path: pathlib.Path): def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table): - transactions = [ + expected_transactions = [ Transaction(app_id="app_1", version=1), Transaction(app_id="app_2", version=2, last_updated=123456), ] - commit_properties = CommitProperties(app_transactions=transactions) + commit_properties = CommitProperties(app_transactions=expected_transactions) write_deltalake( table_or_uri=tmp_path, data=sample_data, @@ -2013,12 +2013,13 @@ def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table): transactions = delta_table.transaction_versions() assert len(transactions) == 2 - assert transactions["app_1"] == { - "appId": "app_1", - "version": 1, - } - assert transactions["app_2"] == { - "appId": "app_2", - "version": 2, - "lastUpdated": 123456, - } + + transaction_1 = transactions["app_1"] + assert transaction_1.app_id == "app_1" + assert transaction_1.version == 1 + assert transaction_1.last_updated is None + + transaction_2 = transactions["app_2"] + assert transaction_2.app_id == "app_2" + assert transaction_2.version == 2 + assert transaction_2.last_updated == 123456