diff --git a/python/Cargo.toml b/python/Cargo.toml index a9936a483c..c5d30dd641 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -42,6 +42,10 @@ reqwest = { version = "*", features = ["native-tls-vendored"] } version = "0.20" features = ["extension-module", "abi3", "abi3-py38"] +[dependencies.pyo3-asyncio] +version = "0.20" +features = ["tokio-runtime"] + [dependencies.deltalake] path = "../crates/deltalake" version = "0" diff --git a/python/src/lib.rs b/python/src/lib.rs index 5741bd40d2..101aabb854 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -4,6 +4,7 @@ mod error; mod filesystem; mod schema; mod utils; +extern crate pyo3_asyncio; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; @@ -52,10 +53,15 @@ use crate::filesystem::FsConfig; use crate::schema::schema_to_pyobject; #[inline] -fn rt() -> PyResult { +fn rt_pyo3() -> PyResult { tokio::runtime::Runtime::new().map_err(|err| PyRuntimeError::new_err(err.to_string())) } +#[inline] +fn rt() -> &'static tokio::runtime::Runtime { + pyo3_asyncio::tokio::get_runtime() +} + #[derive(FromPyObject)] enum PartitionFilterValue<'a> { Single(&'a str), @@ -113,7 +119,7 @@ impl RawDeltaTable { .map_err(PythonError::from)?; } - let table = rt()?.block_on(builder.load()).map_err(PythonError::from)?; + let table = rt().block_on(builder.load()).map_err(PythonError::from)?; Ok(RawDeltaTable { _table: table, _config: FsConfig { @@ -135,7 +141,7 @@ impl RawDeltaTable { ) -> PyResult { let data_catalog = deltalake::data_catalog::get_data_catalog(data_catalog, catalog_options) .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - let table_uri = rt()? + let table_uri = rt() .block_on(data_catalog.get_table_storage_location( data_catalog_id, database_name, @@ -174,13 +180,13 @@ impl RawDeltaTable { } pub fn load_version(&mut self, version: i64) -> PyResult<()> { - Ok(rt()? + Ok(rt() .block_on(self._table.load_version(version)) .map_err(PythonError::from)?) } pub fn get_latest_version(&mut self) -> PyResult { - Ok(rt()? + Ok(rt() .block_on(self._table.get_latest_version()) .map_err(PythonError::from)?) } @@ -190,7 +196,7 @@ impl RawDeltaTable { DateTime::::from(DateTime::::parse_from_rfc3339(ds).map_err( |err| PyValueError::new_err(format!("Failed to parse datetime string: {err}")), )?); - Ok(rt()? + Ok(rt() .block_on(self._table.load_with_datetime(datetime)) .map_err(PythonError::from)?) } @@ -280,7 +286,7 @@ impl RawDeltaTable { if let Some(retention_period) = retention_hours { cmd = cmd.with_retention_period(Duration::hours(retention_period as i64)); } - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -333,7 +339,7 @@ impl RawDeltaTable { cmd = cmd.with_predicate(update_predicate); } - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -361,7 +367,7 @@ impl RawDeltaTable { .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -394,7 +400,7 @@ impl RawDeltaTable { .map_err(PythonError::from)?; cmd = cmd.with_filters(&converted_filters); - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -593,7 +599,7 @@ impl RawDeltaTable { } } - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -624,7 +630,7 @@ impl RawDeltaTable { } cmd = cmd.with_ignore_missing_files(ignore_missing_files); cmd = cmd.with_protocol_downgrade_allowed(protocol_downgrade_allowed); - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -633,7 +639,7 @@ impl RawDeltaTable { /// Run the History command on the Delta Table: Returns provenance information, including the operation, user, and so on, for each write to a table. pub fn history(&mut self, limit: Option) -> PyResult> { - let history = rt()? + let history = rt() .block_on(self._table.history(limit)) .map_err(PythonError::from)?; Ok(history @@ -643,7 +649,7 @@ impl RawDeltaTable { } pub fn update_incremental(&mut self) -> PyResult<()> { - Ok(rt()? + Ok(rt() .block_on(self._table.update_incremental(None)) .map_err(PythonError::from)?) } @@ -821,15 +827,14 @@ impl RawDeltaTable { }; let store = self._table.log_store(); - rt()? - .block_on(commit( - &*store, - &actions, - operation, - self._table.get_state(), - None, - )) - .map_err(PythonError::from)?; + rt().block_on(commit( + &*store, + &actions, + operation, + self._table.get_state(), + None, + )) + .map_err(PythonError::from)?; Ok(()) } @@ -837,23 +842,21 @@ impl RawDeltaTable { pub fn get_py_storage_backend(&self) -> PyResult { Ok(filesystem::DeltaFileSystemHandler { inner: self._table.object_store(), - rt: Arc::new(rt()?), + rt: Arc::new(rt_pyo3()?), config: self._config.clone(), known_sizes: None, }) } pub fn create_checkpoint(&self) -> PyResult<()> { - rt()? - .block_on(create_checkpoint(&self._table)) + rt().block_on(create_checkpoint(&self._table)) .map_err(PythonError::from)?; Ok(()) } pub fn cleanup_metadata(&self) -> PyResult<()> { - rt()? - .block_on(cleanup_metadata(&self._table)) + rt().block_on(cleanup_metadata(&self._table)) .map_err(PythonError::from)?; Ok(()) @@ -875,7 +878,7 @@ impl RawDeltaTable { if let Some(predicate) = predicate { cmd = cmd.with_predicate(predicate); } - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -889,7 +892,7 @@ impl RawDeltaTable { let cmd = FileSystemCheckBuilder::new(self._table.log_store(), self._table.state.clone()) .with_dry_run(dry_run); - let (table, metrics) = rt()? + let (table, metrics) = rt() .block_on(cmd.into_future()) .map_err(PythonError::from)?; self._table.state = table.state; @@ -1076,7 +1079,7 @@ fn batch_distinct(batch: PyArrowType) -> PyResult