diff --git a/Cargo.toml b/Cargo.toml index 7285cf3eb..78b46e449 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ substrait = ["dep:datafusion-substrait"] [dependencies] tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.8" -pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38", "gil-refs"] } +pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] } arrow = { version = "52", feature = ["pyarrow"] } datafusion = { version = "39.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } datafusion-common = { version = "39.0.0", features = ["pyarrow"] } @@ -67,3 +67,4 @@ crate-type = ["cdylib", "rlib"] [profile.release] lto = true codegen-units = 1 + \ No newline at end of file diff --git a/src/common.rs b/src/common.rs index 682639aca..44c557ce7 100644 --- a/src/common.rs +++ b/src/common.rs @@ -23,7 +23,7 @@ pub mod function; pub mod schema; /// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/config.rs b/src/config.rs index 228f95a0b..82a4f93ab 100644 --- a/src/config.rs +++ b/src/config.rs @@ -65,7 +65,7 @@ impl PyConfig { /// Get all configuration options pub fn get_all(&mut self, py: Python) -> PyResult { - let dict = PyDict::new(py); + let dict = PyDict::new_bound(py); let options = self.config.to_owned(); for entry in options.entries() { dict.set_item(entry.key, entry.value.clone().into_py(py))?; diff --git a/src/context.rs b/src/context.rs index 9462d0b86..ec63adb26 100644 --- a/src/context.rs +++ b/src/context.rs @@ -291,11 +291,11 @@ impl PySessionContext { pub fn register_object_store( &mut self, scheme: &str, - store: &PyAny, + store: &Bound<'_, PyAny>, host: Option<&str>, ) -> PyResult<()> { let res: Result<(Arc, String), PyErr> = - match StorageContexts::extract(store) { + match StorageContexts::extract_bound(store) { Ok(store) => match store { StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)), StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)), @@ -443,8 +443,8 @@ impl PySessionContext { ) -> PyResult { Python::with_gil(|py| { // Instantiate pyarrow Table object & convert to Arrow Table - let table_class = py.import("pyarrow")?.getattr("Table")?; - let args = PyTuple::new(py, &[data]); + let table_class = py.import_bound("pyarrow")?.getattr("Table")?; + let args = PyTuple::new_bound(py, &[data]); let table = table_class.call_method1("from_pylist", args)?.into(); // Convert Arrow Table to datafusion DataFrame @@ -463,8 +463,8 @@ impl PySessionContext { ) -> PyResult { Python::with_gil(|py| { // Instantiate pyarrow Table object & convert to Arrow Table - let table_class = py.import("pyarrow")?.getattr("Table")?; - let args = PyTuple::new(py, &[data]); + let table_class = py.import_bound("pyarrow")?.getattr("Table")?; + let args = PyTuple::new_bound(py, &[data]); let table = table_class.call_method1("from_pydict", args)?.into(); // Convert Arrow Table to datafusion DataFrame @@ -507,8 +507,8 @@ impl PySessionContext { ) -> PyResult { Python::with_gil(|py| { // Instantiate pyarrow Table object & convert to Arrow Table - let table_class = py.import("pyarrow")?.getattr("Table")?; - let args = PyTuple::new(py, &[data]); + let table_class = py.import_bound("pyarrow")?.getattr("Table")?; + let args = PyTuple::new_bound(py, &[data]); let table = table_class.call_method1("from_pandas", args)?.into(); // Convert Arrow Table to datafusion DataFrame @@ -710,7 +710,12 @@ impl PySessionContext { } // Registers a PyArrow.Dataset - pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> { + pub fn register_dataset( + &self, + name: &str, + dataset: &Bound<'_, PyAny>, + py: Python, + ) -> PyResult<()> { let table: Arc = Arc::new(Dataset::new(dataset, py)?); self.ctx diff --git a/src/dataframe.rs b/src/dataframe.rs index 8f4514398..1b91067d5 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -28,6 +28,7 @@ use datafusion::prelude::*; use datafusion_common::UnnestOptions; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; use pyo3::types::PyTuple; use tokio::task::JoinHandle; @@ -56,23 +57,25 @@ impl PyDataFrame { #[pymethods] impl PyDataFrame { - fn __getitem__(&self, key: PyObject) -> PyResult { - Python::with_gil(|py| { - if let Ok(key) = key.extract::<&str>(py) { - self.select_columns(vec![key]) - } else if let Ok(tuple) = key.extract::<&PyTuple>(py) { - let keys = tuple - .iter() - .map(|item| item.extract::<&str>()) - .collect::>>()?; - self.select_columns(keys) - } else if let Ok(keys) = key.extract::>(py) { - self.select_columns(keys) - } else { - let message = "DataFrame can only be indexed by string index or indices"; - Err(PyTypeError::new_err(message)) - } - }) + /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]` + fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult { + if let Ok(key) = key.extract::() { + // df[col] + self.select_columns(vec![key]) + } else if let Ok(tuple) = key.extract::<&PyTuple>() { + // df[col1, col2, col3] + let keys = tuple + .iter() + .map(|item| item.extract::()) + .collect::>>()?; + self.select_columns(keys) + } else if let Ok(keys) = key.extract::>() { + // df[[col1, col2, col3]] + self.select_columns(keys) + } else { + let message = "DataFrame can only be indexed by string index or indices"; + Err(PyTypeError::new_err(message)) + } } fn __repr__(&self, py: Python) -> PyResult { @@ -98,7 +101,8 @@ impl PyDataFrame { } #[pyo3(signature = (*args))] - fn select_columns(&self, args: Vec<&str>) -> PyResult { + fn select_columns(&self, args: Vec) -> PyResult { + let args = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().select_columns(&args)?; Ok(Self::new(df)) } @@ -194,7 +198,7 @@ impl PyDataFrame { fn join( &self, right: PyDataFrame, - join_keys: (Vec<&str>, Vec<&str>), + join_keys: (Vec, Vec), how: &str, ) -> PyResult { let join_type = match how { @@ -212,11 +216,22 @@ impl PyDataFrame { } }; + let left_keys = join_keys + .0 + .iter() + .map(|s| s.as_ref()) + .collect::>(); + let right_keys = join_keys + .1 + .iter() + .map(|s| s.as_ref()) + .collect::>(); + let df = self.df.as_ref().clone().join( right.df.as_ref().clone(), join_type, - &join_keys.0, - &join_keys.1, + &left_keys, + &right_keys, None, )?; Ok(Self::new(df)) @@ -414,8 +429,8 @@ impl PyDataFrame { Python::with_gil(|py| { // Instantiate pyarrow Table object and use its from_batches method - let table_class = py.import("pyarrow")?.getattr("Table")?; - let args = PyTuple::new(py, &[batches, schema]); + let table_class = py.import_bound("pyarrow")?.getattr("Table")?; + let args = PyTuple::new_bound(py, &[batches, schema]); let table: PyObject = table_class.call_method1("from_batches", args)?.into(); Ok(table) }) @@ -489,8 +504,8 @@ impl PyDataFrame { let table = self.to_arrow_table(py)?; Python::with_gil(|py| { - let dataframe = py.import("polars")?.getattr("DataFrame")?; - let args = PyTuple::new(py, &[table]); + let dataframe = py.import_bound("polars")?.getattr("DataFrame")?; + let args = PyTuple::new_bound(py, &[table]); let result: PyObject = dataframe.call1(args)?.into(); Ok(result) }) @@ -514,7 +529,7 @@ fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> { // Import the Python 'builtins' module to access the print function // Note that println! does not print to the Python debug console and is not visible in notebooks for instance - let print = py.import("builtins")?.getattr("print")?; + let print = py.import_bound("builtins")?.getattr("print")?; print.call1((result,))?; Ok(()) } diff --git a/src/dataset.rs b/src/dataset.rs index fcbb503c0..724b4af76 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -46,13 +46,14 @@ pub(crate) struct Dataset { impl Dataset { // Creates a Python PyArrow.Dataset - pub fn new(dataset: &PyAny, py: Python) -> PyResult { + pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult { // Ensure that we were passed an instance of pyarrow.dataset.Dataset - let ds = PyModule::import(py, "pyarrow.dataset")?; - let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?; + let ds = PyModule::import_bound(py, "pyarrow.dataset")?; + let ds_attr = ds.getattr("Dataset")?; + let ds_type = ds_attr.downcast::()?; if dataset.is_instance(ds_type)? { Ok(Dataset { - dataset: dataset.into(), + dataset: dataset.clone().unbind(), }) } else { Err(PyValueError::new_err( @@ -73,7 +74,7 @@ impl TableProvider for Dataset { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef { Python::with_gil(|py| { - let dataset = self.dataset.as_ref(py); + let dataset = self.dataset.bind(py); // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never Arc::new( dataset @@ -108,7 +109,7 @@ impl TableProvider for Dataset { ) -> DFResult> { Python::with_gil(|py| { let plan: Arc = Arc::new( - DatasetExec::new(py, self.dataset.as_ref(py), projection.cloned(), filters) + DatasetExec::new(py, self.dataset.bind(py), projection.cloned(), filters) .map_err(|err| DataFusionError::External(Box::new(err)))?, ); Ok(plan) diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs index 8ef3a563e..240c86486 100644 --- a/src/dataset_exec.rs +++ b/src/dataset_exec.rs @@ -53,7 +53,7 @@ impl Iterator for PyArrowBatchesAdapter { fn next(&mut self) -> Option { Python::with_gil(|py| { - let mut batches: &PyIterator = self.batches.as_ref(py); + let mut batches = self.batches.clone().into_bound(py); Some( batches .next()? @@ -79,7 +79,7 @@ pub(crate) struct DatasetExec { impl DatasetExec { pub fn new( py: Python, - dataset: &PyAny, + dataset: &Bound<'_, PyAny>, projection: Option>, filters: &[Expr], ) -> Result { @@ -103,7 +103,7 @@ impl DatasetExec { }) .transpose()?; - let kwargs = PyDict::new(py); + let kwargs = PyDict::new_bound(py); kwargs.set_item("columns", columns.clone())?; kwargs.set_item( @@ -111,7 +111,7 @@ impl DatasetExec { filter_expr.as_ref().map(|expr| expr.clone_ref(py)), )?; - let scanner = dataset.call_method("scanner", (), Some(kwargs))?; + let scanner = dataset.call_method("scanner", (), Some(&kwargs))?; let schema = Arc::new( scanner @@ -120,19 +120,17 @@ impl DatasetExec { .0, ); - let builtins = Python::import(py, "builtins")?; + let builtins = Python::import_bound(py, "builtins")?; let pylist = builtins.getattr("list")?; // Get the fragments or partitions of the dataset - let fragments_iterator: &PyAny = dataset.call_method1( + let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1( "get_fragments", (filter_expr.as_ref().map(|expr| expr.clone_ref(py)),), )?; - let fragments: &PyList = pylist - .call1((fragments_iterator,))? - .downcast() - .map_err(PyErr::from)?; + let fragments_iter = pylist.call1((fragments_iterator,))?; + let fragments = fragments_iter.downcast::().map_err(PyErr::from)?; let projected_statistics = Statistics::new_unknown(&schema); let plan_properties = datafusion::physical_plan::PlanProperties::new( @@ -142,9 +140,9 @@ impl DatasetExec { ); Ok(DatasetExec { - dataset: dataset.into(), + dataset: dataset.clone().unbind(), schema, - fragments: fragments.into(), + fragments: fragments.clone().unbind(), columns, filter_expr, projected_statistics, @@ -183,8 +181,8 @@ impl ExecutionPlan for DatasetExec { ) -> DFResult { let batch_size = context.session_config().batch_size(); Python::with_gil(|py| { - let dataset = self.dataset.as_ref(py); - let fragments = self.fragments.as_ref(py); + let dataset = self.dataset.bind(py); + let fragments = self.fragments.bind(py); let fragment = fragments .get_item(partition) .map_err(|err| InnerDataFusionError::External(Box::new(err)))?; @@ -193,7 +191,7 @@ impl ExecutionPlan for DatasetExec { let dataset_schema = dataset .getattr("schema") .map_err(|err| InnerDataFusionError::External(Box::new(err)))?; - let kwargs = PyDict::new(py); + let kwargs = PyDict::new_bound(py); kwargs .set_item("columns", self.columns.clone()) .map_err(|err| InnerDataFusionError::External(Box::new(err)))?; @@ -207,7 +205,7 @@ impl ExecutionPlan for DatasetExec { .set_item("batch_size", batch_size) .map_err(|err| InnerDataFusionError::External(Box::new(err)))?; let scanner = fragment - .call_method("scanner", (dataset_schema,), Some(kwargs)) + .call_method("scanner", (dataset_schema,), Some(&kwargs)) .map_err(|err| InnerDataFusionError::External(Box::new(err)))?; let schema: SchemaRef = Arc::new( scanner @@ -215,7 +213,7 @@ impl ExecutionPlan for DatasetExec { .and_then(|schema| Ok(schema.extract::>()?.0)) .map_err(|err| InnerDataFusionError::External(Box::new(err)))?, ); - let record_batches: &PyIterator = scanner + let record_batches: Bound<'_, PyIterator> = scanner .call_method0("to_batches") .map_err(|err| InnerDataFusionError::External(Box::new(err)))? .iter() @@ -264,7 +262,7 @@ impl ExecutionPlanProperties for DatasetExec { impl DisplayAs for DatasetExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { Python::with_gil(|py| { - let number_of_fragments = self.fragments.as_ref(py).len(); + let number_of_fragments = self.fragments.bind(py).len(); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let projected_columns: Vec = self @@ -274,7 +272,7 @@ impl DisplayAs for DatasetExec { .map(|x| x.name().to_owned()) .collect(); if let Some(filter_expr) = &self.filter_expr { - let filter_expr = filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?; + let filter_expr = filter_expr.bind(py).str().or(Err(std::fmt::Error))?; write!( f, "DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]", diff --git a/src/expr.rs b/src/expr.rs index 09a773c4d..dc1de669b 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -553,7 +553,7 @@ impl PyExpr { } /// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/ -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/functions.rs b/src/functions.rs index 09cdee619..8e395ae4f 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -670,7 +670,7 @@ aggregate_function!(bit_xor, BitXor); aggregate_function!(bool_and, BoolAnd); aggregate_function!(bool_or, BoolOr); -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; m.add_wrapped(wrap_pyfunction!(acosh))?; diff --git a/src/lib.rs b/src/lib.rs index a696ebff4..71c27e1ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,7 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime); /// The higher-level public API is defined in pure python files under the /// datafusion directory. #[pymodule] -fn _internal(py: Python, m: &PyModule) -> PyResult<()> { +fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { // Register the Tokio Runtime as a module attribute so we can reuse it m.add( "runtime", @@ -94,35 +94,35 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/ - let common = PyModule::new(py, "common")?; - common::init_module(common)?; - m.add_submodule(common)?; + let common = PyModule::new_bound(py, "common")?; + common::init_module(&common)?; + m.add_submodule(&common)?; // Register `expr` as a submodule. Matching `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/ - let expr = PyModule::new(py, "expr")?; - expr::init_module(expr)?; - m.add_submodule(expr)?; + let expr = PyModule::new_bound(py, "expr")?; + expr::init_module(&expr)?; + m.add_submodule(&expr)?; // Register the functions as a submodule - let funcs = PyModule::new(py, "functions")?; - functions::init_module(funcs)?; - m.add_submodule(funcs)?; + let funcs = PyModule::new_bound(py, "functions")?; + functions::init_module(&funcs)?; + m.add_submodule(&funcs)?; - let store = PyModule::new(py, "object_store")?; - store::init_module(store)?; - m.add_submodule(store)?; + let store = PyModule::new_bound(py, "object_store")?; + store::init_module(&store)?; + m.add_submodule(&store)?; // Register substrait as a submodule #[cfg(feature = "substrait")] - setup_substrait_module(py, m)?; + setup_substrait_module(py, &m)?; Ok(()) } #[cfg(feature = "substrait")] -fn setup_substrait_module(py: Python, m: &PyModule) -> PyResult<()> { - let substrait = PyModule::new(py, "substrait")?; - substrait::init_module(substrait)?; - m.add_submodule(substrait)?; +fn setup_substrait_module(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + let substrait = PyModule::new_bound(py, "substrait")?; + substrait::init_module(&substrait)?; + m.add_submodule(&substrait)?; Ok(()) } diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs index 64124fb12..fca885121 100644 --- a/src/pyarrow_filter_expression.rs +++ b/src/pyarrow_filter_expression.rs @@ -32,9 +32,9 @@ pub(crate) struct PyArrowFilterExpression(PyObject); fn operator_to_py<'py>( operator: &Operator, - op: &'py PyModule, -) -> Result<&'py PyAny, DataFusionError> { - let py_op: &PyAny = match operator { + op: &Bound<'py, PyModule>, +) -> Result, DataFusionError> { + let py_op: Bound<'_, PyAny> = match operator { Operator::Eq => op.getattr("eq")?, Operator::NotEq => op.getattr("ne")?, Operator::Lt => op.getattr("lt")?, @@ -96,9 +96,9 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression fn try_from(expr: &Expr) -> Result { Python::with_gil(|py| { - let pc = Python::import(py, "pyarrow.compute")?; - let op_module = Python::import(py, "operator")?; - let pc_expr: Result<&PyAny, DataFusionError> = match expr { + let pc = Python::import_bound(py, "pyarrow.compute")?; + let op_module = Python::import_bound(py, "operator")?; + let pc_expr: Result, DataFusionError> = match expr { Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?), Expr::Literal(v) => match v { ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?), @@ -118,7 +118,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { ))), }, Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let operator = operator_to_py(op, op_module)?; + let operator = operator_to_py(op, &op_module)?; let left = PyArrowFilterExpression::try_from(left.as_ref())?.0; let right = PyArrowFilterExpression::try_from(right.as_ref())?.0; Ok(operator.call1((left, right))?) @@ -131,14 +131,15 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { Expr::IsNotNull(expr) => { let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())? .0 - .into_ref(py); + .into_bound(py); Ok(py_expr.call_method0("is_valid")?) } Expr::IsNull(expr) => { let expr = PyArrowFilterExpression::try_from(expr.as_ref())? .0 - .into_ref(py); - Ok(expr.call_method1("is_null", (expr,))?) + .into_bound(py); + // TODO: this expression does not seems like it should be `call_method0` + Ok(expr.clone().call_method1("is_null", (expr,))?) } Expr::Between(Between { expr, @@ -168,7 +169,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { }) => { let expr = PyArrowFilterExpression::try_from(expr.as_ref())? .0 - .into_ref(py); + .into_bound(py); let scalars = extract_scalar_list(list, py)?; let ret = expr.call_method1("isin", (scalars,))?; let invert = op_module.getattr("invert")?; diff --git a/src/store.rs b/src/store.rs index 542cfa925..846d96a6d 100644 --- a/src/store.rs +++ b/src/store.rs @@ -219,7 +219,7 @@ impl PyAmazonS3Context { } } -pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/substrait.rs b/src/substrait.rs index ff83f6f79..1e9e16c7b 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -40,7 +40,7 @@ impl PyPlan { self.plan .encode(&mut proto_bytes) .map_err(DataFusionError::EncodeError)?; - Ok(PyBytes::new(py, &proto_bytes).into()) + Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into()) } } @@ -76,7 +76,7 @@ impl PySubstraitSerializer { pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult { match PySubstraitSerializer::serialize_bytes(sql, ctx, py) { Ok(proto_bytes) => { - let proto_bytes: &PyBytes = proto_bytes.as_ref(py).downcast().unwrap(); + let proto_bytes = proto_bytes.bind(py).downcast::().unwrap(); PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py) } Err(e) => Err(py_datafusion_err(e)), @@ -87,7 +87,7 @@ impl PySubstraitSerializer { pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult { let proto_bytes: Vec = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx)) .map_err(DataFusionError::from)?; - Ok(PyBytes::new(py, &proto_bytes).into()) + Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into()) } #[staticmethod] @@ -140,7 +140,7 @@ impl PySubstraitConsumer { } } -pub fn init_module(m: &PyModule) -> PyResult<()> { +pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/udaf.rs b/src/udaf.rs index 9aea761cd..7b5e03668 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use pyo3::{prelude::*, types::PyBool, types::PyTuple}; +use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::{Array, ArrayRef}; use datafusion::arrow::datatypes::DataType; @@ -42,12 +42,12 @@ impl RustAccumulator { impl Accumulator for RustAccumulator { fn state(&mut self) -> Result> { - Python::with_gil(|py| self.accum.as_ref(py).call_method0("state")?.extract()) + Python::with_gil(|py| self.accum.bind(py).call_method0("state")?.extract()) .map_err(|e| DataFusionError::Execution(format!("{e}"))) } fn evaluate(&mut self) -> Result { - Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) + Python::with_gil(|py| self.accum.bind(py).call_method0("evaluate")?.extract()) .map_err(|e| DataFusionError::Execution(format!("{e}"))) } @@ -58,11 +58,11 @@ impl Accumulator for RustAccumulator { .iter() .map(|arg| arg.into_data().to_pyarrow(py).unwrap()) .collect::>(); - let py_args = PyTuple::new(py, py_args); + let py_args = PyTuple::new_bound(py, py_args); // 2. call function self.accum - .as_ref(py) + .bind(py) .call_method1("update", py_args) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; @@ -82,7 +82,7 @@ impl Accumulator for RustAccumulator { // 2. call merge self.accum - .as_ref(py) + .bind(py) .call_method1("merge", (state,)) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; @@ -101,11 +101,11 @@ impl Accumulator for RustAccumulator { .iter() .map(|arg| arg.into_data().to_pyarrow(py).unwrap()) .collect::>(); - let py_args = PyTuple::new(py, py_args); + let py_args = PyTuple::new_bound(py, py_args); // 2. call function self.accum - .as_ref(py) + .bind(py) .call_method1("retract_batch", py_args) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; @@ -114,12 +114,12 @@ impl Accumulator for RustAccumulator { } fn supports_retract_batch(&self) -> bool { - Python::with_gil(|py| { - let x: Result<&PyAny, PyErr> = - self.accum.as_ref(py).call_method0("supports_retract_batch"); - let x: &PyAny = x.unwrap_or(PyBool::new(py, false)); - x.extract().unwrap_or(false) - }) + Python::with_gil( + |py| match self.accum.bind(py).call_method0("supports_retract_batch") { + Ok(x) => x.extract().unwrap_or(false), + Err(_) => false, + }, + ) } } diff --git a/src/utils.rs b/src/utils.rs index 62cf07d9e..4334f86cd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -24,13 +24,13 @@ use tokio::runtime::Runtime; /// Utility to get the Tokio Runtime from Python pub(crate) fn get_tokio_runtime(py: Python) -> PyRef { - let datafusion = py.import("datafusion._internal").unwrap(); + let datafusion = py.import_bound("datafusion._internal").unwrap(); let tmp = datafusion.getattr("runtime").unwrap(); match tmp.extract::>() { Ok(runtime) => runtime, Err(_e) => { let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap()); - let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py); + let obj: Bound<'_, TokioRuntime> = Py::new(py, rt).unwrap().into_bound(py); obj.extract().unwrap() } }