Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial commit of updating pyo3 methods #21

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions datafusion/common/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@

//! Conversions between PyArrow and DataFusion types

// TODO update to pyo3 new APIs
// See: https://pyo3.rs/v0.23.0/migration
#![allow(deprecated)]

use arrow::array::ArrayData;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use arrow_array::Array;
use pyo3::exceptions::PyException;
use pyo3::prelude::PyErr;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python};
use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python};

use crate::{DataFusionError, ScalarValue};

Expand All @@ -44,8 +40,8 @@ impl FromPyArrow for ScalarValue {
let val = value.call_method0("as_py")?;

// construct pyarrow array from the python value and pyarrow type
let factory = py.import_bound("pyarrow")?.getattr("array")?;
let args = PyList::new_bound(py, [val]);
let factory = py.import("pyarrow")?.getattr("array")?;
let args = PyList::new(py, [val])?;
let array = factory.call1((args, typ))?;

// convert the pyarrow array to rust array using C data interface
Expand Down Expand Up @@ -73,14 +69,25 @@ impl<'source> FromPyObject<'source> for ScalarValue {
}
}

impl IntoPy<PyObject> for ScalarValue {
fn into_py(self, py: Python) -> PyObject {
self.to_pyarrow(py).unwrap()
impl<'source> IntoPyObject<'source> for ScalarValue {
type Target = PyAny;

type Output = Bound<'source, Self::Target>;

type Error = PyErr;

fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
let array = self.to_array()?;
// convert to pyarrow array using C data interface
let pyarray = array.to_data().to_pyarrow(py)?;
let pyarray_bound = pyarray.bind(py);
pyarray_bound.call_method1("__getitem__", (0,))
}
}

#[cfg(test)]
mod tests {
use pyo3::ffi::c_str;
use pyo3::prepare_freethreaded_python;
use pyo3::py_run;
use pyo3::types::PyDict;
Expand All @@ -90,10 +97,12 @@ mod tests {
fn init_python() {
prepare_freethreaded_python();
Python::with_gil(|py| {
if py.run_bound("import pyarrow", None, None).is_err() {
let locals = PyDict::new_bound(py);
py.run_bound(
"import sys; executable = sys.executable; python_path = sys.path",
if py.run(c_str!("import pyarrow"), None, None).is_err() {
let locals = PyDict::new(py);
py.run(
c_str!(
"import sys; executable = sys.executable; python_path = sys.path"
),
None,
Some(&locals),
)
Expand Down Expand Up @@ -139,17 +148,25 @@ mod tests {
}

#[test]
fn test_py_scalar() {
fn test_py_scalar() -> PyResult<()> {
init_python();

Python::with_gil(|py| {
Python::with_gil(|py| -> PyResult<()> {
let scalar_float = ScalarValue::Float64(Some(12.34));
let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
let py_float = scalar_float
.into_pyobject(py)?
.call_method0("as_py")
.unwrap();
py_run!(py, py_float, "assert py_float == 12.34");

let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string()));
let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
let py_string = scalar_string
.into_pyobject(py)?
.call_method0("as_py")
.unwrap();
py_run!(py, py_string, "assert py_string == 'Hello!'");
});

Ok(())
})
}
}
Loading